Submitted by:
| # | Name | Id | Email |
|---|---|---|---|
| Student 1 | [Jamal Tannous] | [208912337] | [jamaltannous@campus.technion.ac.il] |
| Student 2 | [Snir Hordan] | [205689581] | [snirhordan@campus.technion.ac.il] |
In this assignment we'll create a from-scratch implementation of two fundamental deep learning concepts: the backpropagation algorithm and stochastic gradient descent-based optimizers. Following that, we'll focus on sequences and learn to generate text with a deep multilayer RNN based on GRU cells.
You can of course use any editor or IDE to work on these files.

In this part we will learn about working with text sequences using recurrent neural networks. We'll go from a raw text file all the way to a fully trained GRU-RNN model and generate works of art!
import unittest
import os
import sys
import pathlib
import urllib.request
import shutil
import re
import numpy as np
import torch
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use: %reload_ext autoreload
test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda
Let's begin by downloading a corpus containing all the works of William Shakespeare. Since he was very prolific, this corpus is fairly large and will provide us with enough data for obtaining impressive results.
CORPUS_URL = 'https://github.com/cedricdeboom/character-level-rnn-datasets/raw/master/datasets/shakespeare.txt'
DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
def download_corpus(out_path=DATA_DIR, url=CORPUS_URL, force=False):
pathlib.Path(out_path).mkdir(exist_ok=True)
out_filename = os.path.join(out_path, os.path.basename(url))
if os.path.isfile(out_filename) and not force:
print(f'Corpus file {out_filename} exists, skipping download.')
else:
print(f'Downloading {url}...')
with urllib.request.urlopen(url) as response, open(out_filename, 'wb') as out_file:
shutil.copyfileobj(response, out_file)
print(f'Saved to {out_filename}.')
return out_filename
corpus_path = download_corpus()
Corpus file /home/snirhordan/.pytorch-datasets/shakespeare.txt exists, skipping download.
Load the text into memory and print a snippet:
with open(corpus_path, 'r', encoding='utf-8') as f:
corpus = f.read()
print(f'Corpus length: {len(corpus)} chars')
print(corpus[7:1234])
Corpus length: 6347703 chars
ALLS WELL THAT ENDS WELL
by William Shakespeare
Dramatis Personae
KING OF FRANCE
THE DUKE OF FLORENCE
BERTRAM, Count of Rousillon
LAFEU, an old lord
PAROLLES, a follower of Bertram
TWO FRENCH LORDS, serving with Bertram
STEWARD, Servant to the Countess of Rousillon
LAVACHE, a clown and Servant to the Countess of Rousillon
A PAGE, Servant to the Countess of Rousillon
COUNTESS OF ROUSILLON, mother to Bertram
HELENA, a gentlewoman protected by the Countess
A WIDOW OF FLORENCE.
DIANA, daughter to the Widow
VIOLENTA, neighbour and friend to the Widow
MARIANA, neighbour and friend to the Widow
Lords, Officers, Soldiers, etc., French and Florentine
SCENE:
Rousillon; Paris; Florence; Marseilles
ACT I. SCENE 1.
Rousillon. The COUNT'S palace
Enter BERTRAM, the COUNTESS OF ROUSILLON, HELENA, and LAFEU, all in black
COUNTESS. In delivering my son from me, I bury a second husband.
BERTRAM. And I in going, madam, weep o'er my father's death anew;
but I must attend his Majesty's command, to whom I am now in
ward, evermore in subjection.
LAFEU. You shall find of the King a husband, madam; you, sir, a
father. He that so generally is at all times good must of
The first thing we'll need is to map from each unique character in the corpus to an index that will represent it in our learning process.
TODO: Implement the char_maps() function in the hw3/charnn.py module.
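A minimal sketch of what `char_maps()` could look like. The sorted ordering here is an arbitrary choice of ours (the notebook's actual output below uses a different order); any consistent bijection between chars and indices works.

```python
def char_maps(text):
    # Collect the unique chars and assign each an index in sorted order.
    chars = sorted(set(text))
    char_to_idx = {c: i for i, c in enumerate(chars)}
    idx_to_char = {i: c for c, i in char_to_idx.items()}
    return char_to_idx, idx_to_char
```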
import hw3.charnn as charnn
char_to_idx, idx_to_char = charnn.char_maps(corpus)
print(char_to_idx)
test.assertEqual(len(char_to_idx), len(idx_to_char))
test.assertSequenceEqual(list(char_to_idx.keys()), list(idx_to_char.values()))
test.assertSequenceEqual(list(char_to_idx.values()), list(idx_to_char.keys()))
{'f': 0, '6': 1, 'I': 2, 'W': 3, 'X': 4, 'B': 5, ']': 6, 'C': 7, 'D': 8, '"': 9, '&': 10, '[': 11, ':': 12, 'n': 13, 'i': 14, 't': 15, '\n': 16, 'c': 17, 'F': 18, 'm': 19, 'g': 20, 'e': 21, 'A': 22, 'h': 23, '}': 24, 'v': 25, 'x': 26, '(': 27, 'w': 28, ';': 29, 'J': 30, 'K': 31, '!': 32, '_': 33, '4': 34, 'Y': 35, '1': 36, 'T': 37, 'R': 38, 'u': 39, 'o': 40, ' ': 41, 'U': 42, 'M': 43, ')': 44, 'l': 45, '8': 46, ',': 47, 'd': 48, 'L': 49, '$': 50, 'S': 51, 'b': 52, 'r': 53, 'y': 54, 'q': 55, '3': 56, 'Z': 57, '.': 58, '\ufeff': 59, '7': 60, '5': 61, '-': 62, "'": 63, '<': 64, 'z': 65, 'H': 66, 'V': 67, 'Q': 68, 'G': 69, 'j': 70, 'O': 71, '?': 72, 'k': 73, 'E': 74, '0': 75, '9': 76, 'N': 77, 'a': 78, 'P': 79, 'p': 80, '2': 81, 's': 82}
It seems the corpus contains some strange characters that are very rare and probably due to mistakes. Since each unique char increases the length of the one-hot tensors we'll later use to represent our chars, it's best to remove them.
TODO: Implement the remove_chars() function in the hw3/charnn.py module.
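A sketch of `remove_chars()` under the obvious interpretation: return the cleaned text together with the number of chars removed.

```python
def remove_chars(text, chars_to_remove):
    # Drop every occurrence of the given chars from the text.
    removed = set(chars_to_remove)
    clean = ''.join(c for c in text if c not in removed)
    return clean, len(text) - len(clean)
```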
corpus, n_removed = charnn.remove_chars(corpus, ['}','$','_','<','\ufeff'])
print(f'Removed {n_removed} chars')
# After removing the chars, re-create the mappings
char_to_idx, idx_to_char = charnn.char_maps(corpus)
Removed 34 chars
The next thing we need is an embedding of the characters.
An embedding is a representation of each token from the sequence as a tensor.
For a char-level RNN, our tokens will be chars and we can thus use the simplest possible embedding: encode each char as a one-hot tensor. In other words, each char will be represented
as a tensor whose length is the total number of unique chars (V), which contains all zeros except at the index
corresponding to that specific char.
TODO: Implement the functions chars_to_onehot() and onehot_to_chars() in the hw3/charnn.py module.
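A sketch of the two functions, matching the `int8` dtype that the test below checks; the indexing scheme is ours.

```python
import torch

def chars_to_onehot(text, char_to_idx):
    # Returns an (S, V) int8 tensor whose row t is the one-hot encoding
    # of text[t].
    idx = torch.tensor([char_to_idx[c] for c in text], dtype=torch.long)
    onehot = torch.zeros(len(text), len(char_to_idx), dtype=torch.int8)
    onehot[torch.arange(len(text)), idx] = 1
    return onehot

def onehot_to_chars(onehot, idx_to_char):
    # Invert the encoding by taking the argmax of each row.
    return ''.join(idx_to_char[i.item()] for i in onehot.argmax(dim=1))
```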
# Wrap the actual embedding functions for calling convenience
def embed(text):
return charnn.chars_to_onehot(text, char_to_idx)
def unembed(embedding):
return charnn.onehot_to_chars(embedding, idx_to_char)
text_snippet = corpus[3104:3148]
print(text_snippet)
print(embed(text_snippet[0:3]))
test.assertEqual(text_snippet, unembed(embed(text_snippet)))
test.assertEqual(embed(text_snippet).dtype, torch.int8)
brine a maiden can season her praise in.
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0]], dtype=torch.int8)
We wish to train our model to generate text by constantly predicting what the next char should be based on the past. To that end we'll need to train our recurrent network in a way similar to a classification task. At each timestep, we input a char and set the expected output (label) to be the next char in the original sequence.
We will split our corpus into shorter sequences of length S chars (see question below).
Each sample we provide our model with will therefore be a tensor of shape (S,V) where V is the embedding dimension. Our model will operate sequentially on each char in the sequence.
For each sample, we'll also need a label. This is simply another sequence, shifted by one char so that the label of each char is the next char in the corpus.
TODO: Implement the chars_to_labelled_samples() function in the hw3/charnn.py module.
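A sketch of one way `chars_to_labelled_samples()` could work, assuming (as above) int8 one-hot samples and integer labels; the exact implementation details are up to you.

```python
import torch

def chars_to_labelled_samples(text, char_to_idx, seq_len, device='cpu'):
    # Split the text into N = (len-1)//seq_len sequences of seq_len chars.
    # Labels are the same sequences shifted forward by one char.
    V = len(char_to_idx)
    N = (len(text) - 1) // seq_len
    idx = torch.tensor([char_to_idx[c] for c in text[:N * seq_len + 1]])
    samples = torch.zeros(N * seq_len, V, dtype=torch.int8)
    samples[torch.arange(N * seq_len), idx[:-1]] = 1
    samples = samples.reshape(N, seq_len, V).to(device)
    labels = idx[1:].reshape(N, seq_len).to(device)
    return samples, labels
```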
# Create dataset of sequences
seq_len = 64
vocab_len = len(char_to_idx)
# Create labelled samples
samples, labels = charnn.chars_to_labelled_samples(corpus, char_to_idx, seq_len, device)
print(f'samples shape: {samples.shape}')
print(f'labels shape: {labels.shape}')
# Test shapes
num_samples = (len(corpus) - 1) // seq_len
test.assertEqual(samples.shape, (num_samples, seq_len, vocab_len))
test.assertEqual(labels.shape, (num_samples, seq_len))
# Test content
for _ in range(1000):
# random sample
i = np.random.randint(num_samples, size=(1,))[0]
# Compare to corpus
test.assertEqual(unembed(samples[i]), corpus[i*seq_len:(i+1)*seq_len], msg=f"content mismatch in sample {i}")
# Compare to labels
sample_text = unembed(samples[i])
label_text = str.join('', [idx_to_char[j.item()] for j in labels[i]])
test.assertEqual(sample_text[1:], label_text[0:-1], msg=f"label mismatch in sample {i}")
samples shape: torch.Size([99182, 64, 78])
labels shape: torch.Size([99182, 64])
Let's print a few consecutive samples. You should see that the text continues between them.
import re
import random
i = random.randrange(num_samples-5)
for i in range(i, i+5):
test.assertEqual(len(samples[i]), seq_len)
s = re.sub(r'\s+', ' ', unembed(samples[i])).strip()
print(f'sample [{i}]:\n\t{s}')
sample [7649]:
	o be ballast at her nose. ANTIPHOLUS OF SYRACUSE. Where stood
sample [7650]:
	Belgia, the Netherlands? DROMIO OF SYRACUSE. O, Sir, I did not l
sample [7651]:
	ook so low. To conclude: this drudge or diviner laid claim to
sample [7652]:
	me; call'd me Dromio; swore I was assur'd to her; told me what
sample [7653]:
	privy marks I had about me, as, the mark of my shoulder, the
As usual, instead of feeding one sample at a time into our model's forward we'll work with batches of samples. This means that at every timestep, our model will operate on a batch of chars that come from different sequences. Effectively, this allows us to parallelize training by doing matrix-matrix multiplications instead of matrix-vector multiplications during the forward pass.
An important nuance is that we need the batches to be contiguous, i.e. sample $k$ in batch $j$ should continue sample $k$ from batch $j-1$. The following figure illustrates this:

If we naïvely take consecutive samples into batches, e.g. [0,1,...,B-1], [B,B+1,...,2B-1] and so on, we won't have contiguous
sequences at the same index between adjacent batches.
To accomplish this we need to tell our DataLoader which samples to combine together into one batch.
We do this by implementing a custom PyTorch Sampler, and providing it to our DataLoader.
TODO: Implement the SequenceBatchSampler class in the hw3/charnn.py module.
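One way to produce such indices (a sketch that matches the index pattern shown in the output further down: sample $k$ of batch $j$ is continued by sample $k$ of batch $j+1$):

```python
import torch.utils.data

class SequenceBatchSampler(torch.utils.data.Sampler):
    """Yields sample indices so that index k of batch j is continued
    by index k of batch j+1 (contiguous sequences across batches)."""

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_batches = len(dataset) // batch_size

    def __iter__(self):
        # Batch j, position k gets sample j + k*num_batches, so moving
        # to the next batch at a fixed position advances by one sample.
        for j in range(self.num_batches):
            for k in range(self.batch_size):
                yield j + k * self.num_batches

    def __len__(self):
        # Trailing samples that don't fill a whole batch are dropped.
        return self.num_batches * self.batch_size
```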
from hw3.charnn import SequenceBatchSampler
sampler = SequenceBatchSampler(dataset=range(32), batch_size=10)
sampler_idx = list(sampler)
print('sampler_idx =\n', sampler_idx)
# Test the Sampler
test.assertEqual(len(sampler_idx), 30)
batch_idx = np.array(sampler_idx).reshape(-1, 10)
for k in range(10):
test.assertEqual(np.diff(batch_idx[:, k], n=2).item(), 0)
sampler_idx = [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29]
Even though we're working with sequences, we can still use the standard PyTorch Dataset/DataLoader combo.
For the dataset we can use a built-in class, TensorDataset to return tuples of (sample, label)
from the samples and labels tensors we created above.
The DataLoader will be provided with our custom Sampler so that it generates appropriate batches.
import torch.utils.data
# Create DataLoader returning batches of samples.
batch_size = 32
ds_corpus = torch.utils.data.TensorDataset(samples, labels)
sampler_corpus = SequenceBatchSampler(ds_corpus, batch_size)
dl_corpus = torch.utils.data.DataLoader(ds_corpus, batch_size=batch_size, sampler=sampler_corpus, shuffle=False)
Let's see what that gives us:
print(f'num batches: {len(dl_corpus)}')
x0, y0 = next(iter(dl_corpus))
print(f'shape of a batch of samples: {x0.shape}')
print(f'shape of a batch of labels: {y0.shape}')
num batches: 3100
shape of a batch of samples: torch.Size([32, 64, 78])
shape of a batch of labels: torch.Size([32, 64])
Now let's look at the same sample index across multiple consecutive batches taken from our corpus.
# Check that sentences at the same index of different batches complete each other.
k = random.randrange(batch_size)
for j, (X, y) in enumerate(dl_corpus):
print(f'=== batch {j}, sample {k} ({X[k].shape}): ===')
s = re.sub(r'\s+', ' ', unembed(X[k])).strip()
print(f'\t{s}')
if j==4: break
=== batch 0, sample 3 (torch.Size([64, 78])): ===
	defective for requital Than we to stretch it out. Masters o
=== batch 1, sample 3 (torch.Size([64, 78])): ===
	' th' people, We do request your kindest ears; and, after,
=== batch 2, sample 3 (torch.Size([64, 78])): ===
	Your loving motion toward the common body, To yield wha
=== batch 3, sample 3 (torch.Size([64, 78])): ===
	t passes here. SICINIUS. We are convented Upon a pleasing
=== batch 4, sample 3 (torch.Size([64, 78])): ===
	treaty, and have hearts Inclinable to honour and advance
Finally, our data set is ready so we can focus on our model.
What we'll implement here is a multilayer gated recurrent unit (GRU) model with dropout. This model is a type of RNN which performs similarly to the well-known LSTM model, but is somewhat easier to train because it has fewer parameters. We'll modify the regular GRU slightly by applying dropout to the hidden states passed between layers of the model.
The model accepts an input $\mat{X}\in\set{R}^{S\times V}$ containing a sequence of embedded chars. It returns an output $\mat{Y}\in\set{R}^{S\times V}$ of predictions for the next char and the final hidden state $\mat{H}\in\set{R}^{L\times H}$. Here $S$ is the sequence length, $V$ is the vocabulary size (number of unique chars), $L$ is the number of layers in the model and $H$ is the hidden dimension.
Mathematically, the model's forward function at layer $k\in[1,L]$ and timestep $t\in[1,S]$ can be described as
$$ \begin{align} \vec{z_t}^{[k]} &= \sigma\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xz}}}^{[k]} + \vec{h}_{t-1}^{[k]} {\mattr{W}_{\mathrm{hz}}}^{[k]} + \vec{b}_{\mathrm{z}}^{[k]}\right) \\ \vec{r_t}^{[k]} &= \sigma\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xr}}}^{[k]} + \vec{h}_{t-1}^{[k]} {\mattr{W}_{\mathrm{hr}}}^{[k]} + \vec{b}_{\mathrm{r}}^{[k]}\right) \\ \vec{g_t}^{[k]} &= \tanh\left(\vec{x}^{[k]}_t {\mattr{W}_{\mathrm{xg}}}^{[k]} + (\vec{r_t}^{[k]}\odot\vec{h}_{t-1}^{[k]}) {\mattr{W}_{\mathrm{hg}}}^{[k]} + \vec{b}_{\mathrm{g}}^{[k]}\right) \\ \vec{h_t}^{[k]} &= \vec{z}^{[k]}_t \odot \vec{h}^{[k]}_{t-1} + \left(1-\vec{z}^{[k]}_t\right)\odot \vec{g_t}^{[k]} \end{align} $$
The input to each layer is
$$ \mat{X}^{[k]} = \begin{bmatrix} {\vec{x}_1}^{[k]} \\ \vdots \\ {\vec{x}_S}^{[k]} \end{bmatrix}, $$
where $\mat{X}^{[1]} = \mat{X}$ and, for $k>1$, $\mat{X}^{[k]}$ is the sequence of hidden states produced by layer $k-1$ (with dropout applied between layers).
The output of the entire model is then, $$ \mat{Y} = \mat{X}^{[L+1]} {\mattr{W}_{\mathrm{hy}}} + \mat{B}_{\mathrm{y}} $$
and the final hidden state is $$ \mat{H} = \begin{bmatrix} {\vec{h}_S}^{[1]} \\ \vdots \\ {\vec{h}_S}^{[L]} \end{bmatrix}. $$
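To make the equations concrete, here's a sketch of a single layer's forward computation for one timestep. We store each weight as an (in, out) matrix so the transposes in the equations become plain right-multiplications; the function name and signature are ours.

```python
import torch

def gru_cell_step(x, h_prev, Wxz, Whz, bz, Wxr, Whr, br, Wxg, Whg, bg):
    # One timestep of one GRU layer, term-by-term as in the equations:
    z = torch.sigmoid(x @ Wxz + h_prev @ Whz + bz)       # update gate
    r = torch.sigmoid(x @ Wxr + h_prev @ Whr + br)       # reset gate
    g = torch.tanh(x @ Wxg + (r * h_prev) @ Whg + bg)    # candidate state
    h = z * h_prev + (1 - z) * g                         # next hidden state
    return h
```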
Here's a graphical representation of the GRU's forward pass at each timestep. The $\vec{\tilde{h}}$ in the image is our $\vec{g}$ (candidate next state).

You can see how the reset and update gates allow the model to completely ignore its previous state, completely ignore its input, or use any mixture of the two (since the gates are continuous, with values in $(0,1)$).
Here's a graphical representation of the entire model. You can ignore the $c_t^{[k]}$ (cell state) variables (which are relevant for LSTM models). Our model has only the hidden state, $h_t^{[k]}$. Also notice that we added dropout between layers (i.e., on the up arrows).

The purple tensors are inputs (a sequence and initial hidden state per layer), and the green tensors are outputs (another sequence and final hidden state per layer). Each blue block implements the above forward equations. Blocks that are on the same vertical level are at the same layer, and therefore share parameters.
TODO: Implement the MultilayerGRU class in the hw3/charnn.py module.
in_dim = vocab_len
h_dim = 256
n_layers = 3
model = charnn.MultilayerGRU(in_dim, h_dim, out_dim=in_dim, n_layers=n_layers)
model = model.to(device)
print(model)
# Test forward pass
y, h = model(x0.to(dtype=torch.float, device=device))
print(f'y.shape={y.shape}')
print(f'h.shape={h.shape}')
test.assertEqual(y.shape, (batch_size, seq_len, vocab_len))
test.assertEqual(h.shape, (batch_size, n_layers, h_dim))
test.assertEqual(len(list(model.parameters())), 9 * n_layers + 2)
MultilayerGRU(
  (Layer_0_xz): Linear(in_features=78, out_features=256, bias=True)
  (Layer_0_xr): Linear(in_features=78, out_features=256, bias=True)
  (Layer_0_xg): Linear(in_features=78, out_features=256, bias=True)
  (Layer_0_hz): Linear(in_features=256, out_features=256, bias=False)
  (Layer_0_hr): Linear(in_features=256, out_features=256, bias=False)
  (Layer_0_hg): Linear(in_features=256, out_features=256, bias=False)
  (Layer_0_dropout): Dropout(p=0, inplace=False)
  (Layer_1_xz): Linear(in_features=256, out_features=256, bias=True)
  (Layer_1_xr): Linear(in_features=256, out_features=256, bias=True)
  (Layer_1_xg): Linear(in_features=256, out_features=256, bias=True)
  (Layer_1_hz): Linear(in_features=256, out_features=256, bias=False)
  (Layer_1_hr): Linear(in_features=256, out_features=256, bias=False)
  (Layer_1_hg): Linear(in_features=256, out_features=256, bias=False)
  (Layer_1_dropout): Dropout(p=0, inplace=False)
  (Layer_2_xz): Linear(in_features=256, out_features=256, bias=True)
  (Layer_2_xr): Linear(in_features=256, out_features=256, bias=True)
  (Layer_2_xg): Linear(in_features=256, out_features=256, bias=True)
  (Layer_2_hz): Linear(in_features=256, out_features=256, bias=False)
  (Layer_2_hr): Linear(in_features=256, out_features=256, bias=False)
  (Layer_2_hg): Linear(in_features=256, out_features=256, bias=False)
  (Layer_2_dropout): Dropout(p=0, inplace=False)
  (Output_layer): Linear(in_features=256, out_features=78, bias=True)
)
y.shape=torch.Size([32, 64, 78])
h.shape=torch.Size([32, 3, 256])
Now that we have a model, we can implement text generation based on it. The idea is simple: At each timestep our model receives one char $x_t$ from the input sequence and outputs scores $y_t$ for what the next char should be. We'll convert these scores into a probability over each of the possible chars. In other words, for each input char $x_t$ we create a probability distribution for the next char conditioned on the current one and the state of the model (representing all previous inputs): $$p(x_{t+1}|x_t, \vec{h}_t).$$
Once we have such a distribution, we'll sample a char from it. This will be the first char of our generated sequence. Now we can feed this new char into the model, create another distribution, sample the next char and so on. Note that it's crucial to propagate the hidden state when sampling.
The important point, however, is how to create the distribution from the scores. One way, as we saw in previous ML tasks, is to use the softmax function. However, a drawback of softmax is that it can generate very diffuse (near-uniform) distributions if the score values are similar. When sampling, we would prefer to control the distributions and make them less uniform, increasing the chance of sampling the char(s) with the highest scores.
To control the variance of the distribution, a common trick is to add a hyperparameter $T$, known as the temperature, to the softmax function. The class scores are simply divided by $T$ before the softmax is applied: $$ \mathrm{softmax}_T(\vec{y}) = \frac{e^{\vec{y}/T}}{\sum_k e^{y_k/T}} $$
A low $T$ will result in less uniform distributions and vice-versa.
TODO: Implement the hot_softmax() function in the hw3/charnn.py module.
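A direct translation of this formula into code (a sketch; the signature is our assumption):

```python
import torch

def hot_softmax(y, dim=0, temperature=1.0):
    # Scale the scores by 1/T before softmax: low T sharpens the
    # distribution, high T flattens it toward uniform.
    return torch.softmax(y / temperature, dim=dim)
```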
scores = y[0,0,:].detach()
_, ax = plt.subplots(figsize=(15,5))
for t in reversed([0.3, 0.5, 1.0, 100]):
ax.plot(charnn.hot_softmax(scores, temperature=t).cpu().numpy(), label=f'T={t}')
ax.set_xlabel('$x_{t+1}$')
ax.set_ylabel('$p(x_{t+1}|x_t)$')
ax.legend()
uniform_proba = 1/len(char_to_idx)
uniform_diff = torch.abs(charnn.hot_softmax(scores, temperature=100) - uniform_proba)
test.assertTrue(torch.all(uniform_diff < 1e-4))
TODO: Implement the generate_from_model() function in the hw3/charnn.py module.
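A sketch of the sampling loop. The assumed `model(x, hidden_state) -> (scores, hidden_state)` calling convention and the helper names are ours, not the course API; note how the hidden state `h` is threaded through every call.

```python
import torch

def generate_from_model(model, start_sequence, n_chars, char_maps, T=0.5):
    char_to_idx, idx_to_char = char_maps
    V = len(char_to_idx)

    def onehot(s):
        # One-hot encode a string as a (1, S, V) float batch.
        idx = torch.tensor([char_to_idx[c] for c in s])
        x = torch.zeros(1, len(s), V)
        x[0, torch.arange(len(s)), idx] = 1
        return x

    out_text = start_sequence
    h = None
    with torch.no_grad():
        x = onehot(start_sequence)  # feed the whole seed sequence first
        while len(out_text) < n_chars:
            y, h = model(x, h)                     # propagate hidden state!
            p = torch.softmax(y[0, -1, :] / T, 0)  # temperature softmax
            j = torch.multinomial(p, 1).item()     # sample the next char
            out_text += idx_to_char[j]
            x = onehot(out_text[-1])               # feed only the new char
    return out_text
```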
for _ in range(3):
text = charnn.generate_from_model(model, "foobar", 50, (char_to_idx, idx_to_char), T=0.5)
print(text)
test.assertEqual(len(text), 50)
foobarNXnWT3M3-4R[Ha.D'KlZ8l'h96sahFb&NwG5zt:MqM-w
foobarSh7NuDY zLozpxU:SJI5qA9[ffjWk0kx1bUj4KLoM.Ln
foobarl9qNsE8&e3t.-'03K:9RN]Kb[ 1CktB98;Yd)L-YoQlm
To train this model, we'll calculate the loss at each time step by comparing the predicted char to
the actual char from our label. We can use cross entropy since per char it's similar to a classification problem.
We'll then sum the losses over the sequence and back-propagate the gradients through time.
Notice that the back-propagation algorithm will "visit" each layer's parameter tensors multiple times,
so gradients will accumulate in the parameters of the shared blocks. Luckily, autograd handles this part for us.
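For example, the per-timestep losses can be computed in a single call by flattening the batch and sequence dimensions (this snippet uses the default mean reduction, which only rescales the summed sequence loss by a constant; the tensor names are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(42)

# Treat every timestep as one classification: flatten the (B, S, V) score
# tensor and the (B, S) label tensor before calling cross-entropy.
B, S, V = 4, 8, 10
y_scores = torch.randn(B, S, V)    # stand-in for model outputs
labels = torch.randint(V, (B, S))  # stand-in for next-char indices
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(y_scores.reshape(-1, V), labels.reshape(-1))
print(f'{loss.item():.3f}')
```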
As usual, the first step of training will be to try and overfit a large model (many parameters) to a tiny dataset. Again, this is to ensure the model and training code are implemented correctly, i.e. that the model can learn.
For a generative model such as this, overfitting is slightly trickier than for classification. What we'll aim to do is to get our model to memorize a specific sequence of chars, so that when given the first char in the sequence it will immediately spit out the rest of the sequence verbatim.
Let's create a tiny dataset to memorize.
# Pick a tiny subset of the dataset
subset_start, subset_end = 1001, 1005
ds_corpus_ss = torch.utils.data.Subset(ds_corpus, range(subset_start, subset_end))
batch_size_ss = 1
sampler_ss = SequenceBatchSampler(ds_corpus_ss, batch_size=batch_size_ss)
dl_corpus_ss = torch.utils.data.DataLoader(ds_corpus_ss, batch_size_ss, sampler=sampler_ss, shuffle=False)
# Convert subset to text
subset_text = ''
for i in range(subset_end - subset_start):
subset_text += unembed(ds_corpus_ss[i][0])
print(f'Text to "memorize":\n\n{subset_text}')
Text to "memorize":
TRAM. What would you have?
HELENA. Something; and scarce so much; nothing, indeed.
I would not tell you what I would, my lord.
Faith, yes:
Strangers and foes do sunder and not kiss.
BERTRAM. I pray you, stay not, but in haste to horse.
HE
Now let's implement the first part of our training code.
TODO: Implement the train_epoch() and train_batch() methods of the RNNTrainer class in the hw3/training.py module.
You must think about how to correctly handle the hidden state of the model between batches and epochs for this specific task (i.e. text generation).
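One plausible policy (an assumption on our part, not the reference solution) is to carry the hidden state across contiguous batches but detach it from the autograd graph, so back-propagation-through-time is truncated at batch boundaries, and to reset it to `None` at the start of each epoch:

```python
import torch

def carry_hidden(h):
    # Keep the hidden-state values between contiguous batches, but cut the
    # autograd graph so BPTT stays within a single batch. (An assumption
    # about the intended behavior, not the course's reference solution.)
    return None if h is None else h.detach()
```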
import torch.nn as nn
import torch.optim as optim
from hw3.training import RNNTrainer
torch.manual_seed(42)
lr = 0.01
num_epochs = 500
in_dim = vocab_len
h_dim = 128
n_layers = 2
loss_fn = nn.CrossEntropyLoss()
model = charnn.MultilayerGRU(in_dim, h_dim, out_dim=in_dim, n_layers=n_layers).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
trainer = RNNTrainer(model, loss_fn, optimizer, device)
for epoch in range(num_epochs):
epoch_result = trainer.train_epoch(dl_corpus_ss, verbose=False)
# Every X epochs, we'll generate a sequence starting from the first char in the first sequence
# to visualize how/if/what the model is learning.
if epoch == 0 or (epoch+1) % 25 == 0:
avg_loss = np.mean(epoch_result.losses)
accuracy = np.mean(epoch_result.accuracy)
print(f'\nEpoch #{epoch+1}: Avg. loss = {avg_loss:.3f}, Accuracy = {accuracy:.2f}%')
generated_sequence = charnn.generate_from_model(model, subset_text[0],
seq_len*(subset_end-subset_start),
(char_to_idx,idx_to_char), T=0.1)
# Stop if we've successfully memorized the small dataset.
print(generated_sequence)
if generated_sequence == subset_text:
break
# Test successful overfitting
test.assertGreater(epoch_result.accuracy, 99)
test.assertEqual(generated_sequence, subset_text)
Epoch #1: Avg. loss = 3.843, Accuracy = 17.58%
Tdt
Epoch #25: Avg. loss = 0.034, Accuracy = 99.61%
TRAM. What would you have?
HELENA. Something; and scarce so much; nothing, indeed.
I would not tell you what I would, my lord.
Faith, yes:
Strangers and foes do sunder and not kiss.
BERTRAM. I pray you, stay not, but in haste to horse.
HE
OK, so training works - we can memorize a short sequence. We'll now train a much larger model on our large dataset. You'll need a GPU for this part.
First, let's set up our dataset and model for training. We'll split our corpus into a 90% train set and a 10% test set. We'll also use a learning-rate scheduler to control the learning rate during training.
TODO: Set the hyperparameters in the part1_rnn_hyperparams() function of the hw3/answers.py module.
from hw3.answers import part1_rnn_hyperparams
hp = part1_rnn_hyperparams()
print('hyperparams:\n', hp)
### Dataset definition
vocab_len = len(char_to_idx)
batch_size = hp['batch_size']
seq_len = hp['seq_len']
train_test_ratio = 0.9
num_samples = (len(corpus) - 1) // seq_len
num_train = int(train_test_ratio * num_samples)
samples, labels = charnn.chars_to_labelled_samples(corpus, char_to_idx, seq_len, device)
ds_train = torch.utils.data.TensorDataset(samples[:num_train], labels[:num_train])
sampler_train = SequenceBatchSampler(ds_train, batch_size)
dl_train = torch.utils.data.DataLoader(ds_train, batch_size, shuffle=False, sampler=sampler_train, drop_last=True)
ds_test = torch.utils.data.TensorDataset(samples[num_train:], labels[num_train:])
sampler_test = SequenceBatchSampler(ds_test, batch_size)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size, shuffle=False, sampler=sampler_test, drop_last=True)
print(f'Train: {len(dl_train):3d} batches, {len(dl_train)*batch_size*seq_len:7d} chars')
print(f'Test: {len(dl_test):3d} batches, {len(dl_test)*batch_size*seq_len:7d} chars')
### Training definition
in_dim = out_dim = vocab_len
checkpoint_file = 'checkpoints/rnn'
num_epochs = 50
early_stopping = 5
model = charnn.MultilayerGRU(in_dim, hp['h_dim'], out_dim, hp['n_layers'], hp['dropout'])
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=hp['learn_rate'])
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='max', factor=hp['lr_sched_factor'], patience=hp['lr_sched_patience'], verbose=True
)
trainer = RNNTrainer(model, loss_fn, optimizer, device)
hyperparams:
{'batch_size': 250, 'seq_len': 100, 'h_dim': 250, 'n_layers': 2, 'dropout': 0.01, 'learn_rate': 0.005, 'lr_sched_factor': 0.05, 'lr_sched_patience': 1}
Train: 228 batches, 5700000 chars
Test: 25 batches, 625000 chars
The code blocks below will train the model and save checkpoints containing the training state and the best model parameters to a file. This allows you to stop training and resume it later from where you left off.
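A minimal checkpointing sketch: the `'model_state'` key matches what the loading code in this notebook expects, while the function names and the other fields are illustrative, not the course API.

```python
import torch

def save_checkpoint(path, model, best_acc=None, ewi=0):
    # Persist the model parameters plus whatever training state is needed
    # to resume (best accuracy so far, epochs without improvement, ...).
    torch.save({'model_state': model.state_dict(),
                'best_acc': best_acc,
                'ewi': ewi}, path)

def load_model_state(path, model, device='cpu'):
    # Restore only the model parameters from a saved checkpoint.
    saved_state = torch.load(path, map_location=device)
    model.load_state_dict(saved_state['model_state'])
```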
Note that you can use the main.py script provided within the assignment folder to run this notebook from the command line as if it were a python script by using the run-nb subcommand. This allows you to train your model using this notebook without starting jupyter. You can combine this with srun or sbatch to run the notebook with a GPU on the course servers.
TODO:
- Implement the fit() method of the Trainer class. You can reuse the relevant implementation parts from HW2, but make sure to implement early stopping and checkpoints.
- Implement the test_epoch() and test_batch() methods of the RNNTrainer class in the hw3/training.py module.
- When training completes, rename the saved checkpoint file to checkpoints/rnn_final.pt. This will cause the block to skip training and instead load your saved model when running the homework submission script.

Note that your submission zip file will not include the checkpoint file. This is OK.

from cs236781.plot import plot_fit
def post_epoch_fn(epoch, train_res, test_res, verbose):
# Update learning rate
scheduler.step(test_res.accuracy)
# Sample from model to show progress
if verbose:
start_seq = "ACT I."
generated_sequence = charnn.generate_from_model(
model, start_seq, 100, (char_to_idx,idx_to_char), T=0.5
)
print(generated_sequence)
# Train, unless final checkpoint is found
checkpoint_file_final = f'{checkpoint_file}_final.pt'
if os.path.isfile(checkpoint_file_final):
print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
saved_state = torch.load(checkpoint_file_final, map_location=device)
model.load_state_dict(saved_state['model_state'])
else:
try:
# Print pre-training sampling
print(charnn.generate_from_model(model, "ACT I.", 100, (char_to_idx,idx_to_char), T=0.5))
fit_res = trainer.fit(dl_train, dl_test, num_epochs, max_batches=None,
post_epoch_fn=post_epoch_fn, early_stopping=early_stopping,
checkpoints=checkpoint_file, print_every=1)
fig, axes = plot_fit(fit_res)
except KeyboardInterrupt as e:
print('\n *** Training interrupted by user')
ACT I.w?cz iX?'E64CVELCDz,Yv?K?f"F),)jS HyCm0M3CcX?]LtA?Tr[s,ELIz,Fk8:IvW)?"pIJ1?8??)b!ATqbj2hYC8cd?
--- EPOCH 1/50 ---
train_batch: 0%| | 0/228 [00:00<?, ?it/s]
test_batch: 0%| | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 1
ACT I.
PRINCE. I will be not have mether, that the stand the fine
That make me for our singero
--- EPOCH 2/50 ---
train_batch: 0%| | 0/228 [00:00<?, ?it/s]
test_batch: 0%| | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 2
ACT I.
And shall be in him be but him will be they from
the song of the more commontry of th
--- EPOCH 3/50 ---
train_batch: 0%| | 0/228 [00:00<?, ?it/s]
test_batch: 0%| | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 3
ACT I.
--- EPOCH 4/50 ---
train_batch: 0%| | 0/228 [00:00<?, ?it/s]
test_batch: 0%| | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 4
ACT I. He should be gone.
--- EPOCH 5/50 ---
train_batch: 0%| | 0/228 [00:00<?, ?it/s]
test_batch: 0%| | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 5
ACT I.
Ham. I have the stand be in my brother, and make your persons.
The man hath been a very
--- EPOCH 6/50 ---
train_batch: 0%| | 0/228 [00:00<?, ?it/s]
test_batch: 0%| | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 6
ACT I. Leon. Make me good morrow, good night; I will not be a brave fingers. Most will. I will
--- EPOCH 7/50 ---
train_batch: 0%| | 0/228 [00:00<?, ?it/s]
test_batch: 0%| | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 7
ACT I. I say the King is now
and every truth, thy soul the great particular man's son's person,
--- EPOCH 8/50 ---
train_batch: 0%| | 0/228 [00:00<?, ?it/s]
test_batch: 0%| | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 8
ACT I. There's now the soul
be commended to show the subject of the wars of your spirits,
an
--- EPOCH 9/50 ---
train_batch: 0%| | 0/228 [00:00<?, ?it/s]
test_batch: 0%| | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 9
ACT I. Give me the truth,
And be of grace with such a house of the world,
When she is but th
--- EPOCH 10/50 ---
train_batch: 0%| | 0/228 [00:00<?, ?it/s]
test_batch: 0%| | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 10
ACT I. Well, we are more than the
good will be necessary of the humour of the music with his sha
--- EPOCH 11/50 ---
train_batch: 0%| | 0/228 [00:00<?, ?it/s]
test_batch: 0%| | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 11
ACT I.
Bora. O, my wife hath nothing live and bear a fear
his blood than my mistress' charge.
--- EPOCH 12/50 ---
train_batch: 0%| | 0/228 [00:00<?, ?it/s]
test_batch: 0%| | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 12
ACT I. Bene. Sir, she shall see them slain. Bene. If thou be shouldst thou had been a most princ
--- EPOCH 13/50 ---
train_batch: 0%| | 0/228 [00:00<?, ?it/s]
test_batch: 0%| | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 13
ACT I. Ham. I will return the world when he of these thoughts. Hor. I speak with thee to me. H
--- EPOCH 14/50 ---
train_batch: 0%| | 0/228 [00:00<?, ?it/s]
test_batch: 0%| | 0/25 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/rnn.pt at epoch 14
ACT I.
Bene. I know the more will I shall be but the rest of the stubborn.
--- EPOCH 15/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 15
ACT I.
Hot. I will not well for you. I have done the way
I would not have the mean to the prop
--- EPOCH 16/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 16
Epoch 16: reducing learning rate of group 0 to 2.5000e-04.
ACT I.
PORTIA. I have not not so much to follow my body.
But, what most heavenly part of the d
--- EPOCH 17/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 17
ACT I.
SIR TOBY. And the most contrary knows the cutternest man, and it was
as little in the n
--- EPOCH 18/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 18
ACT I. Exit.
--- EPOCH 19/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 19
ACT I. The fairest corn
I was more than the head, and the one of the sword
To seek their hea
--- EPOCH 20/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 20
ACT I.
DUKE SENIOR. Then were the time of this way speaks to any
And then to the gods for the
--- EPOCH 21/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 21
ACT I. [To Borachio.] Fal. Why, then a man have the world of the rock of thy sight. Fal. What a
--- EPOCH 22/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 22
ACT I.
The greatest strength of this great calpable
That I may stand in thee for that with h
--- EPOCH 23/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 23
ACT I.
LAFEU. And then he was a king of many true love mine honour.
I mean, the which she hath
--- EPOCH 24/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 24
ACT I.
The more that was the grave of the wars of him.
DUKE. A man of good action, and poor th
--- EPOCH 25/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 25
ACT I.
Exit [Drum an
--- EPOCH 26/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 26
ACT I.
O Caesar! Why, the world is distracted to him.
What say you to you and a flower?
CA
--- EPOCH 27/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 27
ACT I.
I have done me with the like deserves of this.
--- EPOCH 28/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 28
ACT I.
CADE. I am a mile as they say; and so I did not so heart.
--- EPOCH 29/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 29
ACT I. COSTARD. This is the world mad man. PANDARUS. What is it thus far for the name of Sicilia
--- EPOCH 30/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 30
ACT I.
The King, what name is the rest of this place?
What then? What news are grown on them
--- EPOCH 31/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 31
ACT I.
Mer. The other for a white world is like to be the man
and the land of a head of form.
--- EPOCH 32/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 32
ACT I.
Exit.
--- EPOCH 33/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 33
ACT I.
Prince. I shall have the skirt and letters of the sound of the
gentleman of a stranger
--- EPOCH 34/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 34
ACT I. Exit. SCENE II. The carping of the Emperor Enter HOSTESS an
--- EPOCH 35/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 35
ACT I.
CAPHIS. I told you there! I have put on him that he would
do not be his own lady. Have
--- EPOCH 36/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 36
ACT I. GONZALO. He hath a thing is the great prince. VIOLA. I have seen the breaking of the fait
--- EPOCH 37/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 37
ACT I.
Leon. So do I think the Duke of England, rot a word.
I would thou speak the point of su
--- EPOCH 38/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 38
ACT I.
What was he presently?
CAPTAIN. The Duke of Norfolk, if thou shalt not see him;
I w
--- EPOCH 39/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 39
ACT I. CLOWN. He is my lord's in the devil to the King. CRESSIDA. I will not know the house when
--- EPOCH 40/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 40
Epoch 40: reducing learning rate of group 0 to 1.2500e-05.
ACT I.
What, what a fat fool speaks to me?
The heavens did solicit you, sir.
You are a c
--- EPOCH 41/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 41
ACT I.
I shall have some mind that I am too much enough.
I will not take the prince excuse o
--- EPOCH 42/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 42
ACT I.
What says the meaning of the court? There's some further
device?
Leon. What's the m
--- EPOCH 43/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 43
ACT I.
CADE. A man as I have no common single princely part of
the prince. What then?
SHAL
--- EPOCH 44/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 44
Epoch 44: reducing learning rate of group 0 to 6.2500e-07.
ACT I.
Exeunt.
Scene IV.
Lesi
--- EPOCH 45/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 45
ACT I.
The fairies was a time to be the word.
ACHILLES. What think'st thou, my lord?
MARCUS.
--- EPOCH 46/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 46
Epoch 46: reducing learning rate of group 0 to 3.1250e-08.
ACT I.
CRESSIDA. He that doth with the parts of his and least behaviour.
I do beseech you all
--- EPOCH 47/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 47
ACT I. COSTARD. I will provide thee not a little. PANDARUS. Good morrow, sir, the better. PARO
--- EPOCH 48/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 48
Epoch 48: reducing learning rate of group 0 to 1.5625e-09.
ACT I.
Come, come, let's go.
--- EPOCH 49/50 ---
*** Saved checkpoint checkpoints/rnn.pt at epoch 49
ACT I.
What, how shall I be commanded?
GONZALO. What are you?
PETRUCHIO. What say'st thou, T
--- EPOCH 50/50 ---
Armed with our fully trained model, let's generate the next Hamlet! You should experiment with modifying the sampling temperature and see what happens.
The text you generate should “look” like a Shakespeare play: old-style English words and sentence structure, directions for the actors (like “Exit/Enter”), sections (Act I/Scene III), etc. There will be no coherent plot, of course, but it should at least seem like a Shakespearean play when not looking too closely. If this is not what you see, go back, debug and/or re-train.
TODO: Specify the generation parameters in the part1_generation_params() function within the hw3/answers.py module.
from hw3.answers import part1_generation_params
start_seq, temperature = part1_generation_params()
generated_sequence = charnn.generate_from_model(
model, start_seq, 10000, (char_to_idx,idx_to_char), T=temperature
)
print(generated_sequence)
When she was just a girly that
her hands and three of my life, and such a devil. I should not say
the hand of all the shooting state when they have found
to the blood. I am a good love, where he is not their voices.
Exeunt.
Scene II.
A mother of the house of Gloucester's castle.
Enter Polonius, and Shallow. Exit [Laurence] and Cordelia, and Fortune's command].
Enter Montano, and Saint Lear,
Enter Cordelia.
Greg. Be thou the fall of heaven!
Exeunt.
Scene III.
A court and Bardolph.
Enter Paris to the Parloguus, and Lear and Claudio.
Enter Kent.
Osw. What is your will?
I will not hear the body of his service.
Exeunt.
Scene II.
A hollow within the King's nephew so strange.
Enter Don Pedro and Claudio.
Fal. I have a learned writ o' th' market-place, and thou wert
the form of her three, out of the world.
Prince. Why, she lies one that stands upon the court of the trumpet.
[Exit Peter.]
For then, my lord, and then I have made a prince
and be your father's foot to be the lady that have made
your company of many with all the world is she in the book
of the soul, and so are the world of the way to starve and honest man.
Beat. No.
Ham. I will not be found and wise, and so well the poor roaring of a
glorious end. Exeunt.
Scene III.
Another part of the field and the King
Pol. Where have I so?
Fal. Go to! Marry, my lord, my lord.
Prince. Who is the matter of our part?
Ham. So did I remember thee again. I thank you and my lord to the act.
Beat. I would prove the water-tomb bound to me be the services of the
court of the bottom of the manner of the best of the maids of
the sun with two sons, and his hands with the gates of his
villain. I have prizes me that the truth is dead, and let him see the
hand, and so be one. I am sure I believe the best of the present of
the poor opinions. Farewell. I will be with me to her.
Prince. Well, let me see my lord.
Peto. What is the matter?
Bene. What says he?
Pedro. Good Master Constable, and the more that I have said with the
strength of the mouth. He hath been consul, that thou art
so perfected and hath held my love; and he is to be my
reply.
Pedro. What a perfecter stand the man, you shall see you?
Bene. I do not draw the love of your command.
Beat. I will not answer to him. I have a bound and so hated a
great man.
Pedro. Why, then I am sure so swear thee.
Pedro. That is a good angel that I had the better than the
virtue of the contrary. What is your will?
Pedro. I will not prove upon you.
Bene. They say the time is the same instrument of the street in his
court. The fool will let them prove a man as they say; and there
was not a man when the gates of this same sun that she that
seem to be so valiant in the banks of my lord and my will.
Petruchio. He is very well said.
Prince. From the King, my lord, this day will be a man he will be the
hand.
Fal. I am sworn attendants, and the fair and the letter in the streets of
hard soldiers that they are sent for that we heard her beholding
worthy a father.
Prince. Here is the basket in his house. I will not speak with the book
of the town before the lady.
Prince. What, have you a great strife I would have me call your worship in this
sovereign?
Prince. I have done to the Senate have I should not shake him to you. I
will have the witness of it. If thou be a thousand thanks to
have my soul to the first for the throat, and the secrets of
it will be the day with the gods before the field and finds the stream
more than the most accusation of my mother, and he cannot
be old and come to be bound to my will.
Pedro. Then the world is more than the senators of the hands
of the three-pound. So do you so hard than a natural saint in the
wife, and the streets of the best of the worst that will have the
camp upon our proper sport, and the manner of the maids of his
discovery that is the gates of the master-souls of our means to
the gods to counterfeit the army of the common power
and see the tongue of the air. I know the table and come off a man of
the service, and a cardinal of my bosom, they are so.
Exeunt.
Scene III.
Stanley.
Enter The Palace.
Ham. This is the world to tell you himself and good night
I know the matter of the gates of thy face.
For that the speediest soldiers hath a thing
To such a service to his own desires,
And bear me up and fighting to desire.
And there is this submission of his bosom,
I am a man of heaven in the remembrance
And broke our father's flame to travel me.
Madam, I will not speak with me at friend.
Rom. So much in me a stranger with the bosom
With second time with a bark to the cause.
And to each other shall make haste before;
And what there is no matter for this court?
Who is not what thou dost before me?
The bridegroom of his death to me into the state,
With fair and rash and dangerous confession.
What means this to a death and poor and faces?
Ben. The part of the strong isle sits in his part,
And in the boy of Him that should have seen.
The same of them that is a word or stark.
I love thee in a stranger, that the world
Were such a better summer lives in heaven.
Romeo will be the better that were there,
And I will find a mountain town and letters
The sun with thoughts of such a sea to thee.
I see the court I took the ground.
I hope the letter, if thou wert a bloody
To make thee like a father's wife with this,
To see them call me truth. The realm so can
That they desire to see the world in state
But that the sun are then have found my presence.
Let us not be a prince's body and fair earth,
And this the bed that men have been to read
The first of their officers with my country,
And will she say the news, and then a woman's death.
My son is sound and straight from them to see.
Without a power to be my heart before,
And then thou hast the fairest prosperous gods,
And God help the other of the devil.
What consequence, the foul soul of the crown?
I think some constant state and life to say 'I was,
Did then the belly book on him at home,
Which they that show'd his father with their blood,
As this should stay with all the rage of war.
All loves and good contempt, stand as the senators,
The sun of your brother being here a father
That it is now a child shall think the King.
Alb. Set thy hands and shallow it that I shall shake.
[Exit Tom.
Exeunt.
Scene II.
Sorversete of the castle.
Enter Messengers, Angelo, and Beatrice.
Friar. Let no more be the lady of the King, and so conduct
the great company.
Ham. I will do it.
Pol. I know not, good my lord.
Bene. I would the fine hand here to be done.
Ham. Ay, for he hath forgot the man of the cause of charity!
What, will you come to thy charge? He is a woman so.
God save you, sir, he hath his beauty to be bound.
Ham. Ay, and see it in my land.
Exeunt.
Scene III.
Elsinore. A street
Enter Edgar.
Lear. And there is no man so betrayed as the state
With the contents of the commons of the King,
That she shall be so far.
Exeunt.
Scene II.
Elsinore. A street
Enter Antonio and Soldiers, and others doth the King.
King. Why, then I shall be so best love.
Fran. Go tall what we conceive the name of this fac'd
honourable and destruction. This man is a punishment,
so in this shepherd, and to stand and fight at my own soul
Is not so true as they are set down a care.
Ham. No, thou art angry.
Hot. The same state shall be so, he was not like a man.
Exit.
Gon. My wife were nothing entertainment, but I
than merry than the gods are at his head. If thou be fair, and
should be so well as much as the sun will make the compass of the
orchard for the door with the chain of his accusations. I must confess
the law to strike the late fingers of the particular, and the King's poor
three or five of mine own proper man.
Edm. I know not what I have made the word of the proof.
Lear. What's the man of the part? Why, the best and such a thousand phease
that he is a rover- for a story of a devil and the dog
that they say.
Osr. You are a man.
Claud. I know not. I will be in his writing to my wife and make it off
the world in the streets. I am a beast of it in the sun that you shall
think this will see you and the priest of the maids of a fair
country.
Pedro. I am glad to speak with you. If I love you, sir; so I do not know
the sea to the city.
Ham. I would I do not say the gods stand for that your part.
Ham. What can you proceed in her father.
Bene. This is not with a world in the matter.
Prince. Are you not a poor and the truth with the least of the town? What
thing is the time of his country are the compounded with the robb'd court?
Good Master Gower, you shine!
Prince. I am a man, an'
TODO: Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.
from cs236781.answers import display_answer
import hw3.answers
Why do we split the corpus into sequences instead of training on the whole text?
display_answer(hw3.answers.part1_q1)
Your answer:
The corpus is large, so loading it into memory in its entirety would require substantial memory resources and slow down the training process. We avoid this by breaking the corpus into short sequences and loading one batch of sequences into memory at a time. Training on short sequences also keeps the backpropagation-through-time graph small, instead of unrolling the network over the entire corpus.
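The splitting described above can be sketched as follows (an illustration only; `split_corpus` is a hypothetical helper, not the one implemented in the homework code):

```python
import numpy as np

def split_corpus(encoded: np.ndarray, seq_len: int) -> np.ndarray:
    """Chop a 1-D array of character indices into rows of length seq_len,
    dropping the leftover tail that doesn't fill a whole sequence."""
    n_seqs = len(encoded) // seq_len
    return encoded[:n_seqs * seq_len].reshape(n_seqs, seq_len)

corpus = np.arange(25)               # stand-in for an encoded corpus
seqs = split_corpus(corpus, seq_len=10)
print(seqs.shape)                    # (2, 10); the last 5 chars are dropped
```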
How is it possible that the generated text clearly shows memory longer than the sequence length?
display_answer(hw3.answers.part1_q2)
Your answer: Memory in RNNs derives from the hidden state, which the network uses to predict the next character of the sequence it was trained on. Our network demonstrates memory longer than the sequences it was trained on because the hidden state is carried over between consecutive sequences during training, so it learns the interconnections between sequences and generalizes to the entire corpus.
Why are we not shuffling the order of batches when training?
display_answer(hw3.answers.part1_q3)
Your answer: We don't shuffle the order of the batches during training because the hidden state is passed between consecutive batches, so we want to feed the text in its correct order. Preserving the original order maintains a correct, logical relationship between consecutive sentences and lets the model take context into account. This helps our model generate text that resembles the original corpus.
display_answer(hw3.answers.part1_q4)
Your answer:
a. We lower the temperature to make the conditional distribution of the next character, given the current one, as dissimilar to a uniform distribution as possible. If the distribution were close to uniform, sampling from it would yield very unpredictable and thus uninformative results.
b. The probability over the outputs with temperature $T$ is defined as $\mathrm{softmax}_T(\bb{y})_i = e^{y_i/T} / \sum_j e^{y_j/T}$. If $T$ is very large, then every exponent $y_i/T$ is close to 0, so each numerator is close to 1 and the denominator is close to $n$ (the number of classes); hence every output gets probability roughly $1/n$, i.e. a distribution close to uniform.
c. Using a very low temperature means that the distribution becomes very peaked, i.e. very far from uniform. As a consequence, the model chooses only the characters it is most certain about, without taking any risk on other characters. This yields generated text with a very constrained set of characters, because the other characters have almost no chance of being picked by the model.
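As a quick numerical sanity check of the temperature-scaled softmax discussed above (a minimal sketch; the scores in `y` are arbitrary, not from the model):

```python
import numpy as np

def softmax_T(y, T):
    """Temperature-scaled softmax: exp(y_i/T) / sum_j exp(y_j/T)."""
    z = y / T
    z = z - z.max()          # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

y = np.array([2.0, 1.0, 0.1])
print(softmax_T(y, T=100.0))  # high T: near-uniform, each entry close to 1/3
print(softmax_T(y, T=0.05))   # low T: nearly one-hot on the largest score
```

This shows concretely why low temperatures constrain sampling to the most likely characters while high temperatures wash out the model's preferences.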
In this part we will learn to generate new data using a special type of autoencoder model which allows us to sample from its latent space. We'll implement and train a VAE and use it to generate new images.
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile
import numpy as np
import torch
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2
test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cuda
Let's begin by downloading a dataset of images that we want to learn to generate. We'll use the Labeled Faces in the Wild (LFW) dataset which contains many labeled faces of famous individuals.
We're going to train our generative model to generate a specific face, not just any face. Since the person with the most images in this dataset is former president George W. Bush, we'll set out to train a Bush Generator :)
However, if you feel adventurous and/or prefer to generate something else, feel free
to edit the PART2_CUSTOM_DATA_URL variable in hw3/answers.py.
import cs236781.plot as plot
import cs236781.download
from hw3.answers import PART2_CUSTOM_DATA_URL as CUSTOM_DATA_URL
DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
if CUSTOM_DATA_URL is None:
DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip'
else:
DATA_URL = CUSTOM_DATA_URL
_, dataset_dir = cs236781.download.download_data(out_path=DATA_DIR, url=DATA_URL, extract=True, force=False)
File /home/snirhordan/.pytorch-datasets/lfw-bush.zip exists, skipping download. Extracting /home/snirhordan/.pytorch-datasets/lfw-bush.zip... Extracted 531 to /home/snirhordan/.pytorch-datasets/lfw/George_W_Bush
Create a Dataset object that will load the extracted images:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
im_size = 64
tf = T.Compose([
# Resize to constant spatial dimensions
T.Resize((im_size, im_size)),
# PIL.Image -> torch.Tensor
T.ToTensor(),
# Dynamic range [0,1] -> [-1, 1]
T.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5)),
])
ds_gwb = ImageFolder(os.path.dirname(dataset_dir), tf)
OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.
_ = plot.dataset_first_n(ds_gwb, 50, figsize=(15,10), nrows=5)
print(f'Found {len(ds_gwb)} images in dataset folder.')
Found 530 images in dataset folder.
x0, y0 = ds_gwb[0]
x0 = x0.unsqueeze(0).to(device)
print(x0.shape)
test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))
torch.Size([1, 3, 64, 64])
An autoencoder is a model which learns a representation of data in an unsupervised fashion (i.e. without any labels). Recall its general form from the lecture:

An autoencoder maps an instance $\bb{x}$ to a latent-space representation $\bb{z}$. It has an encoder part, $\Phi_{\bb{\alpha}}(\bb{x})$ (a model with parameters $\bb{\alpha}$) and a decoder part, $\Psi_{\bb{\beta}}(\bb{z})$ (a model with parameters $\bb{\beta}$).
While autoencoders can learn useful representations, generally it's hard to use them as generative models because there's no distribution we can sample from in the latent space. In other words, we have no way to choose a point $\bb{z}$ in the latent space such that $\Psi(\bb{z})$ will end up on the data manifold in the instance space.

The variational autoencoder (VAE), first proposed by Kingma and Welling, addresses this issue by taking a probabilistic perspective. Briefly, a VAE model can be described as follows.
We define, in Bayesian terminology,
To create our variational decoder we'll further specify:
This setting allows us to generate a new instance $\bb{x}$ by sampling $\bb{z}$ from the multivariate normal distribution, obtaining the instance-space mean $\Psi _{\bb{\beta}}(\bb{z})$ using our decoder network, and then sampling $\bb{x}$ from $\mathcal{N}( \Psi _{\bb{\beta}}(\bb{z}) , \sigma^2 \bb{I} )$.
Our variational encoder will approximate the posterior with a parametric distribution $q _{\bb{\alpha}}(\bb{Z} | \bb{x}) = \mathcal{N}( \bb{\mu} _{\bb{\alpha}}(\bb{x}), \mathrm{diag}\{ \bb{\sigma}^2_{\bb{\alpha}}(\bb{x}) \} )$. The interpretation is that our encoder model, $\Phi_{\vec{\alpha}}(\bb{x})$, calculates the mean and variance of the posterior distribution, and samples $\bb{z}$ based on them. An important nuance here is that our network can't contain any stochastic elements that depend on the model parameters, otherwise we won't be able to back-propagate to those parameters. So sampling $\bb{z}$ from $\mathcal{N}( \bb{\mu} _{\bb{\alpha}}(\bb{x}), \mathrm{diag}\{ \bb{\sigma}^2_{\bb{\alpha}}(\bb{x}) \} )$ is not an option. The solution is to use what's known as the reparametrization trick: sample from an isotropic Gaussian, i.e. $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ (which doesn't depend on trainable parameters), and calculate the latent representation as $\bb{z} = \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{u}\odot\bb{\sigma}_{\bb{\alpha}}(\bb{x})$.
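The reparametrization trick described above can be sketched in a few lines (a hedged illustration, not the homework's `VAE.encode()` implementation; shapes are arbitrary):

```python
import torch

def reparameterize(mu, log_sigma2):
    """z = mu + u * sigma, with u ~ N(0, I).
    The randomness lives in u, which has no trainable parameters,
    so gradients flow through mu and sigma to the encoder."""
    sigma = torch.exp(0.5 * log_sigma2)   # sigma = sqrt(exp(log sigma^2))
    u = torch.randn_like(sigma)           # stochastic, parameter-free node
    return mu + u * sigma

mu = torch.zeros(4, 2, requires_grad=True)
log_sigma2 = torch.zeros(4, 2, requires_grad=True)
z = reparameterize(mu, log_sigma2)
z.sum().backward()
print(mu.grad.shape)  # gradients reach the encoder outputs
```

Sampling `z` directly from `torch.distributions.Normal(mu, sigma).sample()` would instead cut the computation graph at the sampling step.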
To train a VAE model, we maximize the evidence distribution, $p(\bb{X})$ (see question below). The VAE loss can therefore be stated as minimizing $\mathcal{L} = -\mathbb{E}_{\bb{x}} \log p(\bb{X})$. Although this expectation is intractable, we can obtain a lower-bound for $p(\bb{X})$ (the evidence lower bound, "ELBO", shown in the lecture):
$$ \log p(\bb{X}) \ge \mathbb{E} _{\bb{z} \sim q _{\bb{\alpha}} }\left[ \log p _{\bb{\beta}}(\bb{X} | \bb{z}) \right] - \mathcal{D} _{\mathrm{KL}}\left(q _{\bb{\alpha}}(\bb{Z} | \bb{X})\,\left\|\, p(\bb{Z} )\right.\right) $$where $ \mathcal{D} _{\mathrm{KL}}(q\left\|\right.p) = \mathbb{E}_{\bb{z}\sim q}\left[ \log \frac{q(\bb{Z})}{p(\bb{Z})} \right] $ is the Kullback-Leibler divergence, which can be interpreted as the information gained by using the posterior $q(\bb{Z|X})$ instead of the prior distribution $p(\bb{Z})$.
Using the ELBO, the VAE loss becomes,
$$ \mathcal{L}(\vec{\alpha},\vec{\beta}) = \mathbb{E} _{\bb{x}} \left[ \mathbb{E} _{\bb{z} \sim q _{\bb{\alpha}} }\left[ -\log p _{\bb{\beta}}(\bb{x} | \bb{z}) \right] + \mathcal{D} _{\mathrm{KL}}\left(q _{\bb{\alpha}}(\bb{Z} | \bb{x})\,\left\|\, p(\bb{Z} )\right.\right) \right]. $$
By remembering that the likelihood is a Gaussian distribution with a diagonal covariance and by applying the reparametrization trick, we can write the above as
$$ \mathcal{L}(\vec{\alpha},\vec{\beta}) = \mathbb{E} _{\bb{x}} \left[ \mathbb{E} _{\bb{z} \sim q _{\bb{\alpha}} } \left[ \frac{1}{2\sigma^2}\left\| \bb{x}- \Psi _{\bb{\beta}}\left( \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{\Sigma}^{\frac{1}{2}} _{\bb{\alpha}}(\bb{x}) \bb{u} \right) \right\| _2^2 \right] + \mathcal{D} _{\mathrm{KL}}\left(q _{\bb{\alpha}}(\bb{Z} | \bb{x})\,\left\|\, p(\bb{Z} )\right.\right) \right]. $$
Obviously our model will have two parts, an encoder and a decoder. Since we're working with images, we'll implement both as deep convolutional networks, where the decoder is a "mirror image" of the encoder implemented with adjoint (AKA transposed) convolutions. Between the encoder CNN and the decoder CNN we'll implement the sampling from the parametric posterior approximator $q_{\bb{\alpha}}(\bb{Z}|\bb{x})$ to make it a VAE model and not just a regular autoencoder (of course, this is not yet enough to create a VAE, since we also need a special loss function which we'll get to later).
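Since the approximate posterior and the prior are both Gaussian, the KL term has a closed form, $\frac{1}{2}\sum_i \left(\sigma_i^2 + \mu_i^2 - 1 - \log\sigma_i^2\right)$, which makes the loss easy to evaluate. A minimal sketch (the function name and the likelihood-variance hyperparameter `x_sigma2` are illustrative, not the homework's API):

```python
import torch
import torch.nn.functional as F

def vae_loss(x, x_rec, mu, log_sigma2, x_sigma2=0.1):
    """Pointwise VAE loss: Gaussian reconstruction term + KL(q(z|x) || N(0, I))."""
    # Reconstruction term, up to constants, of -log p(x|z) for a Gaussian likelihood
    rec = F.mse_loss(x_rec, x, reduction='mean') / x_sigma2
    # Closed-form KL between N(mu, diag(sigma2)) and N(0, I), averaged over the batch
    kl = 0.5 * torch.mean(torch.sum(
        torch.exp(log_sigma2) + mu ** 2 - 1 - log_sigma2, dim=1))
    return rec + kl

x = torch.rand(8, 3, 64, 64)
mu, log_sigma2 = torch.zeros(8, 2), torch.zeros(8, 2)
print(vae_loss(x, x, mu, log_sigma2))  # perfect reconstruction, standard normal posterior: loss is 0
```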
First let's implement just the CNN part of the Encoder network (this is not the full $\Phi_{\vec{\alpha}}(\bb{x})$ yet). As usual, it should take an input image and map it to an activation volume of a specified depth. We'll consider this volume as the features we extract from the input image. Later we'll use these to create the latent-space representation of the input.
import hw3.autoencoder as autoencoder
in_channels = 3
out_channels = 1024
encoder_cnn = autoencoder.EncoderCNN(in_channels, out_channels).to(device)
print(encoder_cnn)
h = encoder_cnn(x0)
print(h.shape)
test.assertEqual(h.dim(), 4)
test.assertSequenceEqual(h.shape[0:2], (1, out_channels))
EncoderCNN(
(cnn): Sequential(
(0): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): Conv2d(128, 512, kernel_size=(5, 5), stride=(2, 2))
(7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU()
(9): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))
(10): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU()
)
)
torch.Size([1, 1024, 5, 5])
Now let's implement the CNN part of the Decoder.
Again this is not yet the full $\Psi _{\bb{\beta}}(\bb{z})$. It should take an activation volume produced
by your EncoderCNN and output an image of the same dimensions as the Encoder's input was.
This can be a CNN which is like a "mirror image" of the Encoder. For example, replace convolutions with transposed convolutions, downsampling with up-sampling, etc.
Consult the documentation of ConvTranspose2D
to figure out how to reverse your convolutional layers in terms of input and output dimensions. Note that the decoder doesn't have to be exactly the opposite of the encoder and you can experiment with using a different architecture.
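As a concrete shape-arithmetic sketch (layer sizes here are illustrative, not the required architecture): `Conv2d` gives $H_{out} = \lfloor (H_{in} + 2p - k)/s \rfloor + 1$, while `ConvTranspose2d` gives $H_{out} = (H_{in} - 1)s - 2p + k + \text{output\_padding}$, so `output_padding` is what recovers sizes lost to the floor operation:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=5, stride=2)
# output_padding=1 compensates for the rounding in the forward conv
deconv = nn.ConvTranspose2d(8, 3, kernel_size=5, stride=2, output_padding=1)

x = torch.randn(1, 3, 64, 64)
h = conv(x)      # spatial size: (64 - 5)//2 + 1 = 30
y = deconv(h)    # spatial size: (30 - 1)*2 + 5 + 1 = 64
print(h.shape, y.shape)
```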
TODO: Implement the DecoderCNN class in the hw3/autoencoder.py module.
decoder_cnn = autoencoder.DecoderCNN(in_channels=out_channels, out_channels=in_channels).to(device)
print(decoder_cnn)
x0r = decoder_cnn(h)
print(x0r.shape)
test.assertEqual(x0.shape, x0r.shape)
# Should look like colored noise
T.functional.to_pil_image(x0r[0].cpu().detach())
DecoderCNN(
(cnn): Sequential(
(0): ConvTranspose2d(1024, 512, kernel_size=(5, 5), stride=(2, 2))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): ConvTranspose2d(512, 128, kernel_size=(5, 5), stride=(2, 2))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2))
(7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU()
(9): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(1, 1))
(10): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
torch.Size([1, 3, 64, 64])
Let's now implement the full VAE Encoder, $\Phi_{\vec{\alpha}}(\vec{x})$. It will work as follows:
$$
\begin{align}
\bb{\mu} _{\bb{\alpha}}(\bb{x}) &= \vec{h}\mattr{W}_{\mathrm{h\mu}} + \vec{b}_{\mathrm{h\mu}} \\
\log\left(\bb{\sigma}^2_{\bb{\alpha}}(\bb{x})\right) &= \vec{h}\mattr{W}_{\mathrm{h\sigma^2}} + \vec{b}_{\mathrm{h\sigma^2}}
\end{align}
$$
Notice that we model the log of the variance, not the actual variance. The above formulation is proposed in appendix C of the VAE paper.
TODO: Implement the encode() method in the VAE class within the hw3/autoencoder.py module.
You'll also need to define your parameters in __init__().
z_dim = 2
vae = autoencoder.VAE(encoder_cnn, decoder_cnn, x0[0].size(), z_dim).to(device)
print(vae)
z, mu, log_sigma2 = vae.encode(x0)
test.assertSequenceEqual(z.shape, (1, z_dim))
test.assertTrue(z.shape == mu.shape == log_sigma2.shape)
print(f'mu(x0)={list(*mu.detach().cpu().numpy())}, sigma2(x0)={list(*torch.exp(log_sigma2).detach().cpu().numpy())}')
VAE(
(features_encoder): EncoderCNN(
(cnn): Sequential(
(0): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): Conv2d(128, 512, kernel_size=(5, 5), stride=(2, 2))
(7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU()
(9): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1))
(10): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU()
)
)
(features_decoder): DecoderCNN(
(cnn): Sequential(
(0): ConvTranspose2d(1024, 512, kernel_size=(5, 5), stride=(2, 2))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): ConvTranspose2d(512, 128, kernel_size=(5, 5), stride=(2, 2))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2))
(7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU()
(9): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(1, 1))
(10): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(log): Linear(in_features=25600, out_features=2, bias=True)
(reconstruct): Linear(in_features=2, out_features=25600, bias=True)
(mu): Linear(in_features=25600, out_features=2, bias=True)
)
mu(x0)=[-0.1058653, 0.08127165], sigma2(x0)=[1.9030428, 0.90913814]
Let's sample some 2d latent representations for an input image x0 and visualize them.
# Sample from q(Z|x)
N = 500
Z = torch.zeros(N, z_dim)
_, ax = plt.subplots()
with torch.no_grad():
for i in range(N):
Z[i], _, _ = vae.encode(x0)
ax.scatter(*Z[i].cpu().numpy())
# Should be close to the mu/sigma in the previous block above
print('sampled mu', torch.mean(Z, dim=0))
print('sampled sigma2', torch.var(Z, dim=0))
sampled mu tensor([-0.0600, 0.0384])
sampled sigma2 tensor([3.6430, 0.7947])
Let's now implement the full VAE Decoder, $\Psi _{\bb{\beta}}(\bb{z})$. It maps a latent sample back into image space via a linear projection followed by the transposed-convolution network.
TODO: Implement the decode() method in the VAE class within the hw3/autoencoder.py module.
You'll also need to define your parameters in __init__(). You may need to also re-run the block above after you implement this.
x0r = vae.decode(z)
test.assertSequenceEqual(x0r.shape, x0.shape)
Our model's forward() function will simply return decode(encode(x)) as well as the calculated mean and log-variance of the posterior.
x0r, mu, log_sigma2 = vae(x0)
test.assertSequenceEqual(x0r.shape, x0.shape)
test.assertSequenceEqual(mu.shape, (1, z_dim))
test.assertSequenceEqual(log_sigma2.shape, (1, z_dim))
T.functional.to_pil_image(x0r[0].detach().cpu())
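The `forward()` contract is easy to see on a toy stand-in (the class and layer names here are hypothetical; the real VAE uses the CNN encoder/decoder and linear heads defined earlier):

```python
import torch

class TinyVAE(torch.nn.Module):
    # Toy stand-in illustrating the forward() contract: return the
    # reconstruction together with mu and the log-variance.
    def __init__(self, d):
        super().__init__()
        self.mu_head = torch.nn.Linear(d, 2)
        self.logvar_head = torch.nn.Linear(d, 2)
        self.dec = torch.nn.Linear(2, d)

    def encode(self, x):
        mu, log_sigma2 = self.mu_head(x), self.logvar_head(x)
        z = mu + torch.exp(0.5 * log_sigma2) * torch.randn_like(mu)
        return z, mu, log_sigma2

    def decode(self, z):
        return self.dec(z)

    def forward(self, x):
        z, mu, log_sigma2 = self.encode(x)
        return self.decode(z), mu, log_sigma2

x = torch.randn(3, 8)
xr, mu, log_sigma2 = TinyVAE(8)(x)
```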
In practice, since we're using SGD, we'll drop the expectation over $\bb{X}$ and instead sample an instance from the training set and compute a point-wise loss. Similarly, we'll drop the expectation over $\bb{Z}$ by sampling from $q_{\vec{\alpha}}(\bb{Z}|\bb{x})$. Additionally, because the KL divergence is between two Gaussian distributions, there is a closed-form expression for it. These points bring us to the following point-wise loss:
$$ \ell(\vec{\alpha},\vec{\beta};\bb{x}) = \frac{1}{\sigma^2 d_x} \left\| \bb{x}- \Psi _{\bb{\beta}}\left( \bb{\mu} _{\bb{\alpha}}(\bb{x}) + \bb{\Sigma}^{\frac{1}{2}} _{\bb{\alpha}}(\bb{x}) \bb{u} \right) \right\| _2^2 + \mathrm{tr}\,\bb{\Sigma} _{\bb{\alpha}}(\bb{x}) + \|\bb{\mu} _{\bb{\alpha}}(\bb{x})\|^2 _2 - d_z - \log\det \bb{\Sigma} _{\bb{\alpha}}(\bb{x}), $$where $d_z$ is the dimension of the latent space, $d_x$ is the dimension of the input and $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$. This pointwise loss is the quantity that we'll compute and minimize with gradient descent. The first term corresponds to the data-reconstruction loss, while the second term corresponds to the KL-divergence loss. Note that the scaling by $d_x$ is not derived from the original loss formula and was added directly to the pointwise loss just to normalize the data term.
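A minimal sketch of this loss for a diagonal posterior covariance, so that $\mathrm{tr}\,\bb{\Sigma}$ and $\log\det\bb{\Sigma}$ reduce to sums over $\exp(\log\sigma^2_i)$ and $\log\sigma^2_i$ (the batch-mean reduction and exact normalization are assumptions; match them to your `vae_loss` implementation):

```python
import torch

def vae_loss_sketch(x, xr, z_mu, z_log_sigma2, x_sigma2):
    N = x.shape[0]
    d_x = x[0].numel()
    d_z = z_mu.shape[1]
    # Data term: ||x - xr||^2 / (sigma^2 * d_x), per sample
    data = (x - xr).reshape(N, -1).pow(2).sum(dim=1) / (x_sigma2 * d_x)
    # KL term with diagonal Sigma: tr(Sigma) + ||mu||^2 - d_z - log det(Sigma)
    kldiv = (torch.exp(z_log_sigma2).sum(dim=1) + z_mu.pow(2).sum(dim=1)
             - d_z - z_log_sigma2.sum(dim=1))
    return (data + kldiv).mean(), data.mean(), kldiv.mean()

# Sanity check: KL vanishes when the posterior equals the prior N(0, I)
# and the data term vanishes for a perfect reconstruction.
x = torch.randn(4, 3, 8, 8)
loss, data, kl = vae_loss_sketch(x, x.clone(), torch.zeros(4, 2), torch.zeros(4, 2), 1.0)
print(loss.item())  # -> 0.0
```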
TODO: Implement the vae_loss() function in the hw3/autoencoder.py module.
from hw3.autoencoder import vae_loss
torch.manual_seed(42)
def test_vae_loss():
# Test data
N, C, H, W = 10, 3, 64, 64
z_dim = 32
x = torch.randn(N, C, H, W)*2 - 1
xr = torch.randn(N, C, H, W)*2 - 1
z_mu = torch.randn(N, z_dim)
z_log_sigma2 = torch.randn(N, z_dim)
x_sigma2 = 0.9
loss, _, _ = vae_loss(x, xr, z_mu, z_log_sigma2, x_sigma2)
test.assertAlmostEqual(loss.item(), 58.3234367, delta=1e-3)
return loss
test_vae_loss()
tensor(58.3234)
The main advantage of a VAE is that it can be used as a generative model by sampling from the latent space, since the loss optimizes for an isotropic Gaussian prior $p(\bb{Z})$. Let's now implement this so that we can visualize how our model is doing during training.
TODO: Implement the sample() method in the VAE class within the hw3/autoencoder.py module.
samples = vae.sample(5)
_ = plot.tensors_as_images(samples)
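Under the hood, `sample()` only needs to draw latent vectors from the prior and decode them; a sketch with a hypothetical stand-in decoder (the real method would use the trained `DecoderCNN` path and return images):

```python
import torch

def sample_sketch(decoder, n, z_dim):
    # Draw n latent vectors from the prior p(Z) = N(0, I) and decode them.
    # no_grad: sampling is inference-only, no gradients needed.
    with torch.no_grad():
        z = torch.randn(n, z_dim)
        return decoder(z)

dec = torch.nn.Linear(2, 12)  # stand-in for the trained VAE decoder
samples = sample_sketch(dec, n=5, z_dim=2)
print(samples.shape)  # torch.Size([5, 12])
```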
Time to train!
TODO:
- Implement the VAETrainer class in the hw3/training.py module. Make sure to implement the checkpoints feature of the Trainer class if you haven't done so already in Part 1.
- Set the hyperparameters in the part2_vae_hyperparams() function within the hw3/answers.py module.

import torch.optim as optim
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torch.nn import DataParallel
from hw3.training import VAETrainer
from hw3.answers import part2_vae_hyperparams
torch.manual_seed(42)
# Hyperparams
hp = part2_vae_hyperparams()
batch_size = hp['batch_size']
h_dim = hp['h_dim']
z_dim = hp['z_dim']
x_sigma2 = hp['x_sigma2']
learn_rate = hp['learn_rate']
betas = hp['betas']
# Data
n_train = int(len(ds_gwb)*0.9)
split_lengths = [n_train, len(ds_gwb) - n_train]  # lengths must sum to len(ds_gwb)
ds_train, ds_test = random_split(ds_gwb, split_lengths)
dl_train = DataLoader(ds_train, batch_size, shuffle=True)
dl_test = DataLoader(ds_test, batch_size, shuffle=True)
im_size = ds_train[0][0].shape
# Model
encoder = autoencoder.EncoderCNN(in_channels=im_size[0], out_channels=h_dim)
decoder = autoencoder.DecoderCNN(in_channels=h_dim, out_channels=im_size[0])
vae = autoencoder.VAE(encoder, decoder, im_size, z_dim)
vae_dp = DataParallel(vae).to(device)
# Optimizer
optimizer = optim.Adam(vae.parameters(), lr=learn_rate, betas=betas)
# Loss
def loss_fn(x, xr, z_mu, z_log_sigma2):
return autoencoder.vae_loss(x, xr, z_mu, z_log_sigma2, x_sigma2)
# Trainer
trainer = VAETrainer(vae_dp, loss_fn, optimizer, device)
checkpoint_file = 'checkpoints/vae'
checkpoint_file_final = f'{checkpoint_file}_final'
if os.path.isfile(f'{checkpoint_file}.pt'):
os.remove(f'{checkpoint_file}.pt')
# Show model and hypers
print(vae)
print(hp)
VAE(
(features_encoder): EncoderCNN(
(cnn): Sequential(
(0): Conv2d(3, 64, kernel_size=(5, 5), stride=(2, 2))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): Conv2d(64, 128, kernel_size=(5, 5), stride=(2, 2))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): Conv2d(128, 512, kernel_size=(5, 5), stride=(2, 2))
(7): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU()
(9): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
(10): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU()
)
)
(features_decoder): DecoderCNN(
(cnn): Sequential(
(0): ConvTranspose2d(512, 512, kernel_size=(5, 5), stride=(2, 2))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU()
(3): ConvTranspose2d(512, 128, kernel_size=(5, 5), stride=(2, 2))
(4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU()
(6): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2))
(7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU()
(9): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(1, 1))
(10): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(log): Linear(in_features=12800, out_features=256, bias=True)
(reconstruct): Linear(in_features=256, out_features=12800, bias=True)
(mu): Linear(in_features=12800, out_features=256, bias=True)
)
{'batch_size': 32, 'h_dim': 512, 'z_dim': 256, 'x_sigma2': 0.00095, 'learn_rate': 9e-05, 'betas': (0.99, 0.998)}
TODO: Run the following block to train the model. Once you're satisfied with the results, rename your checkpoint file by adding the suffix _final. When you run the main.py script to generate your submission, the final checkpoints file will be loaded instead of running training. Note that your final submission zip will not include the checkpoints/ folder. This is OK. The images you get should be colorful, with different backgrounds and poses.
import IPython.display
def post_epoch_fn(epoch, train_result, test_result, verbose):
# Plot some samples if this is a verbose epoch
if verbose:
samples = vae.sample(n=5)
fig, _ = plot.tensors_as_images(samples, figsize=(6,2))
IPython.display.display(fig)
plt.close(fig)
if os.path.isfile(f'{checkpoint_file_final}.pt'):
print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
checkpoint_file = checkpoint_file_final
else:
res = trainer.fit(dl_train, dl_test,
num_epochs=200, early_stopping=20, print_every=10,
checkpoints=checkpoint_file,
post_epoch_fn=post_epoch_fn)
# Plot images from best model
saved_state = torch.load(f'{checkpoint_file}.pt', map_location=device)
vae_dp.load_state_dict(saved_state['model_state'])
print('*** Images Generated from best model:')
fig, _ = plot.tensors_as_images(vae_dp.module.sample(n=15), nrows=3, figsize=(6,6))
--- EPOCH 1/200 ---
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 1
[... identical per-epoch progress output elided; a checkpoint was saved at every epoch through epoch 180 ...]
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 181
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 182
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 183
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 184
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 185
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 186
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 187
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 188
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 189
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 190 --- EPOCH 191/200 ---
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 191
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 192
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 193
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 194
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 195
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 196
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 197
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 198
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 199 --- EPOCH 200/200 ---
train_batch: 0%| | 0/15 [00:00<?, ?it/s]
test_batch: 0%| | 0/2 [00:00<?, ?it/s]
*** Saved checkpoint checkpoints/vae.pt at epoch 200
*** Images Generated from best model:
TODO: Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.
from cs236781.answers import display_answer
import hw3.answers as answers
What does the $\sigma^2$ hyperparameter (x_sigma2 in the code) do? Explain the effect of low and high values.
display_answer(answers.part2_q1)
Your answer: The hyperparameter $\sigma^2$ sets the variance of the likelihood $p(\bb{x}|\bb{z})$, i.e. how far a generated sample is allowed to deviate from the decoder's mean. Since the reconstruction term is weighted by $1/\sigma^2$, low values make the reconstruction loss dominate: the model is strongly constrained by the data it has seen, and generated images stay close to the training data. High values relax this constraint (the KL regularization dominates), so the model may produce images that differ more from the learned data.
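A minimal sketch of this trade-off (not the hw3 implementation; function and variable names here are illustrative assumptions): under a Gaussian likelihood $p(\bb{x}|\bb{z})=\mathcal{N}(\Psi(\bb{z}),\sigma^2\bb{I})$, the reconstruction term is scaled by $1/\sigma^2$, so `x_sigma2` acts as an inverse weight on the data term relative to the KL regularizer.

```python
import torch

# Hedged sketch: how x_sigma2 enters a VAE-style loss.
def vae_loss_sketch(x, x_rec, z_mu, z_log_sigma2, x_sigma2):
    # Reconstruction (data) term, inversely weighted by sigma^2
    recon = torch.mean((x - x_rec) ** 2) / x_sigma2
    # Closed-form KL( N(mu, diag(exp(log_sigma2))) || N(0, I) ), batch mean
    kl = 0.5 * torch.mean(torch.sum(
        torch.exp(z_log_sigma2) + z_mu ** 2 - 1 - z_log_sigma2, dim=1))
    return recon + kl

x, x_rec = torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8)
z_mu, z_log_sigma2 = torch.zeros(4, 10), torch.zeros(4, 10)
loss_low = vae_loss_sketch(x, x_rec, z_mu, z_log_sigma2, x_sigma2=0.1)
loss_high = vae_loss_sketch(x, x_rec, z_mu, z_log_sigma2, x_sigma2=10.0)
# Low sigma2 inflates the reconstruction penalty; high sigma2 suppresses it.
print(loss_low > loss_high)  # tensor(True)
```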
display_answer(answers.part2_q2)
Your answer: 1) Reconstruction loss: gives us a measure of how well the decoder reconstructs $\bb{x}$. KL divergence loss: a regularizer that measures how much information we lose when using $q$ to represent $p$.
2) The effect of the KL loss on the latent-space distribution is as follows: for each instance $\bb{x}$, the KL loss adjusts z_mu and z_sigma2 by penalizing posteriors that stray from the prior, pulling the latent distribution toward $\mathcal{N}(\bb{0},\bb{I})$.
3) The benefit of this effect lies in the improvement of the generation task: it yields smooth interpolations between classes and removes discontinuities ("holes") in the latent space, so samples drawn from the prior decode to plausible images.
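The KL term mentioned above has a standard closed form for a diagonal-Gaussian posterior against the standard-normal prior; the following is a sketch with assumed shapes and names (not the hw3 API):

```python
import torch

# KL( N(mu, diag(sigma^2)) || N(0, I) ) = 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2),
# computed per instance; it penalizes posteriors that stray from the prior.
def kl_to_standard_normal(mu, log_sigma2):
    return 0.5 * torch.sum(torch.exp(log_sigma2) + mu ** 2 - 1 - log_sigma2, dim=1)

# A posterior far from the prior pays a larger penalty than one matching it.
far = kl_to_standard_normal(torch.full((1, 8), 3.0), torch.zeros(1, 8))
matched = kl_to_standard_normal(torch.zeros(1, 8), torch.zeros(1, 8))
print(matched, far)  # tensor([0.]) then a positive value
```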
In the formulation of the VAE loss, why do we start by maximizing the evidence distribution, $p(\bb{X})$?
display_answer(answers.part2_q3)
Your answer: In the formulation of the VAE loss, we start by maximizing the evidence distribution $p(\bb{X})$ because it is the probability our model assigns to the observed data. Since the training instances are assumed to be sampled from the actual data distribution, maximizing $p(\bb{X})$ (maximum likelihood) drives the model toward a proper approximation of that distribution. As $p(\bb{X})$ itself is intractable, we then maximize a lower bound on it instead.
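For reference, the standard decomposition behind this step (not part of the original answer) is
$$ \log p(\bb{X}) = \underbrace{\mathbb{E}_{\bb{z} \sim q(\bb{Z}|\bb{X})} \log \frac{p(\bb{X},\bb{Z})}{q(\bb{Z}|\bb{X})}}_{\mathrm{ELBO}} + \mathcal{D}_{\mathrm{KL}}\left(q(\bb{Z}|\bb{X}) \,\|\, p(\bb{Z}|\bb{X})\right) \geq \mathrm{ELBO}, $$
since the KL term is non-negative. Maximizing the tractable ELBO therefore maximizes a lower bound on the evidence $\log p(\bb{X})$.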
In the VAE encoder, why do we model the log of the latent-space variance corresponding to an input, $\bb{\sigma}^2_{\bb{\alpha}}$, instead of directly modelling this variance?
display_answer(answers.part2_q4)
Your answer: We model $\log \sigma^2$ because the encoder's output is an unconstrained real number, while a variance must be strictly positive: exponentiating the output guarantees positivity (and is numerically stable for very small variances). Working in log-space also turns a product of probabilities into a sum of their logs; this is valid because $\log$ is monotonically increasing, so the maximizer does not change. We can make this argument because each training instance is assumed to be sampled from the actual data distribution.
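A tiny sketch of the positivity point (illustrative only): the network emits any real value, and exponentiation maps it to a valid variance.

```python
import torch

# Predicting log(sigma^2) lets the final linear layer output any real
# number; exponentiating guarantees a strictly positive variance without
# clamping, and tiny variances stay well-behaved in log-space.
log_sigma2 = torch.tensor([-20.0, 0.0, 5.0])  # unconstrained network output
sigma2 = torch.exp(log_sigma2)                # always > 0
print((sigma2 > 0).all())  # tensor(True)

# Reparameterization then uses sigma = exp(0.5 * log_sigma2):
mu = torch.zeros(3)
z = mu + torch.exp(0.5 * log_sigma2) * torch.randn(3)
```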
In this part we will implement and train a generative adversarial network and apply it to the task of image generation.
import unittest
import os
import sys
import pathlib
import urllib
import shutil
import re
import zipfile
import numpy as np
import torch
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2
test = unittest.TestCase()
plt.rcParams.update({'font.size': 12})
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
Using device: cpu
We'll use the same data as in Part 2.
But again, you can use a custom dataset, by editing the PART3_CUSTOM_DATA_URL variable in hw3/answers.py.
import cs236781.plot as plot
import cs236781.download
from hw3.answers import PART3_CUSTOM_DATA_URL as CUSTOM_DATA_URL
DATA_DIR = pathlib.Path.home().joinpath('.pytorch-datasets')
if CUSTOM_DATA_URL is None:
DATA_URL = 'http://vis-www.cs.umass.edu/lfw/lfw-bush.zip'
else:
DATA_URL = CUSTOM_DATA_URL
_, dataset_dir = cs236781.download.download_data(out_path=DATA_DIR, url=DATA_URL, extract=True, force=False)
File /home/kali/.pytorch-datasets/lfw-bush.zip exists, skipping download. Extracting /home/kali/.pytorch-datasets/lfw-bush.zip... Extracted 531 to /home/kali/.pytorch-datasets/lfw/George_W_Bush
Create a Dataset object that will load the extracted images:
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
im_size = 64
tf = T.Compose([
# Resize to constant spatial dimensions
T.Resize((im_size, im_size)),
# PIL.Image -> torch.Tensor
T.ToTensor(),
# Dynamic range [0,1] -> [-1, 1]
T.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5)),
])
ds_gwb = ImageFolder(os.path.dirname(dataset_dir), tf)
OK, let's see what we got. You can run the following block multiple times to display a random subset of images from the dataset.
_ = plot.dataset_first_n(ds_gwb, 50, figsize=(15,10), nrows=5)
print(f'Found {len(ds_gwb)} images in dataset folder.')
Found 530 images in dataset folder.
x0, y0 = ds_gwb[0]
x0 = x0.unsqueeze(0).to(device)
print(x0.shape)
test.assertSequenceEqual(x0.shape, (1, 3, im_size, im_size))
torch.Size([1, 3, 64, 64])
GANs, first proposed in a paper by Ian Goodfellow in 2014, are today arguably the most popular type of generative model. GANs currently produce state-of-the-art results in generative tasks across many different domains.
In a GAN model, two different neural networks compete against each other: A generator and a discriminator.
The Generator, which we'll denote as $\Psi _{\bb{\gamma}} : \mathcal{U} \rightarrow \mathcal{X}$, maps a latent-space variable $\bb{u}\sim\mathcal{N}(\bb{0},\bb{I})$ to an instance-space variable $\bb{x}$ (e.g. an image). Thus a parametric evidence distribution $p_{\bb{\gamma}}(\bb{X})$ is generated, which we typically would like to be as close as possible to the real evidence distribution, $p(\bb{X})$.
The Discriminator, $\Delta _{\bb{\delta}} : \mathcal{X} \rightarrow [0,1]$, is a network which, given an instance-space variable $\bb{x}$, returns the probability that $\bb{x}$ is real, i.e. that $\bb{x}$ was sampled from $p(\bb{X})$ and not $p_{\bb{\gamma}}(\bb{X})$.

The generator is trained to generate "fake" instances which will maximally fool the discriminator into returning that they're real. Mathematically, the generator's parameters $\bb{\gamma}$ should be chosen such as to maximize the expression $$ \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$
The discriminator is trained to classify between real images, coming from the training set, and fake images generated by the generator. Mathematically, the discriminator's parameters $\bb{\delta}$ should be chosen such as to maximize the expression $$ \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$
These two competing objectives can thus be expressed as the following min-max optimization: $$ \min _{\bb{\gamma}} \max _{\bb{\delta}} \, \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$
A key insight into GANs is that we can interpret the above maximum as the loss with respect to $\bb{\gamma}$:
$$ L({\bb{\gamma}}) = \max _{\bb{\delta}} \, \mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, + \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$
This means that the generator's loss function trains together with the generator itself in an adversarial manner. In contrast, when training our VAE we used a fixed L2 norm as a data loss term.
We'll now implement a Deep Convolutional GAN (DCGAN) model. See the DCGAN paper for architecture ideas and tips for training.
TODO: Implement the Discriminator class in the hw3/gan.py module.
If you wish you can reuse the EncoderCNN class from the VAE model as the first part of the Discriminator.
import hw3.gan as gan
dsc = gan.Discriminator(in_size=x0[0].shape).to(device)
print(dsc)
d0 = dsc(x0)
print(d0.shape)
test.assertSequenceEqual(d0.shape, (1,1))
Discriminator(
(dsc): Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
)
)
torch.Size([1, 1])
TODO: Implement the Generator class in the hw3/gan.py module.
If you wish you can reuse the DecoderCNN class from the VAE model as the last part of the Generator.
z_dim = 128
gen = gan.Generator(z_dim, 4).to(device)
print(gen)
z = torch.randn(1, z_dim).to(device)
xr = gen(z)
print(xr.shape)
test.assertSequenceEqual(x0.shape, xr.shape)
Generator(
(cnn): Sequential(
(0): ConvTranspose2d(128, 512, kernel_size=(4, 4), stride=(1, 1))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU(inplace=True)
(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(13): Tanh()
)
)
torch.Size([1, 3, 64, 64])
Let's begin with the discriminator's loss function. Based on the above we can flip the sign and say we want to update the Discriminator's parameters $\bb{\delta}$ so that they minimize the expression $$ -\mathbb{E} _{\bb{x} \sim p(\bb{X}) } \log \Delta _{\bb{\delta}}(\bb{x}) \, - \, \mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (1-\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )). $$
We're using the Discriminator twice in this expression; once to classify data from the real data distribution and once again to classify generated data. Therefore our loss should be computed based on these two terms. Notice that since the discriminator returns a probability, we can formulate the above as two cross-entropy losses.
GANs are notoriously difficult to train. One common trick for improving GAN stability during training is to make the classification labels noisy for the discriminator. This can be seen as a form of regularization, to help prevent the discriminator from overfitting.
We'll incorporate this idea into our loss function. Instead of labels being equal to 0 or 1, we'll make them "fuzzy", i.e. random numbers in the ranges $[0\pm\epsilon]$ and $[1\pm\epsilon]$.
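A minimal sketch of such a loss, under two assumptions that are design choices rather than the hw3 specification: the discriminator outputs raw scores (logits), and noisy labels are drawn uniformly from a band of width `label_noise` around each clean label.

```python
import torch
import torch.nn.functional as F

# Hedged sketch (one possible convention, not the required implementation):
# two cross-entropy terms, one for real data and one for generated data,
# each against uniformly jittered labels.
def discriminator_loss_sketch(y_data, y_generated, data_label=1, label_noise=0.3):
    half = label_noise / 2
    noisy_real = data_label + torch.empty_like(y_data).uniform_(-half, half)
    noisy_fake = (1 - data_label) + torch.empty_like(y_generated).uniform_(-half, half)
    loss_data = F.binary_cross_entropy_with_logits(y_data, noisy_real)
    loss_generated = F.binary_cross_entropy_with_logits(y_generated, noisy_fake)
    return loss_data + loss_generated

scores_real = torch.randn(16)
scores_fake = torch.randn(16)
loss = discriminator_loss_sketch(scores_real, scores_fake)
```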
TODO: Implement the discriminator_loss_fn() function in the hw3/gan.py module.
from hw3.gan import discriminator_loss_fn
torch.manual_seed(42)
y_data = torch.rand(10) * 10
y_generated = torch.rand(10) * 10
loss = discriminator_loss_fn(y_data, y_generated, data_label=1, label_noise=0.3)
print(loss)
test.assertAlmostEqual(loss.item(), 6.4808731, delta=1e-5)
tensor(6.4809)
Similarly, the generator's parameters $\bb{\gamma}$ should minimize the expression $$ -\mathbb{E} _{\bb{z} \sim p(\bb{Z}) } \log (\Delta _{\bb{\delta}}(\Psi _{\bb{\gamma}} (\bb{z}) )) $$
which can also be seen as a cross-entropy term. This corresponds to "fooling" the discriminator; notice that the gradient of the loss w.r.t. $\bb{\gamma}$ using this expression also depends on $\bb{\delta}$.
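A hedged sketch of this term, again assuming the discriminator outputs raw logits (not the required hw3 code): minimizing cross-entropy against the "real" label on generated samples is exactly $-\mathbb{E}\log\Delta(\Psi(\bb{z}))$.

```python
import torch
import torch.nn.functional as F

# The generator wants its fakes classified as real (data_label), so the
# loss is cross-entropy between discriminator scores on fakes and the
# "real" target.
def generator_loss_sketch(y_generated, data_label=1):
    target = torch.full_like(y_generated, float(data_label))
    return F.binary_cross_entropy_with_logits(y_generated, target)

high_scores = torch.full((8,), 5.0)   # discriminator already fooled
low_scores = torch.full((8,), -5.0)   # discriminator not fooled
# Loss is small when the discriminator is fooled, large otherwise.
print(generator_loss_sketch(high_scores) < generator_loss_sketch(low_scores))
```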
TODO: Implement the generator_loss_fn() function in the hw3/gan.py module.
from hw3.gan import generator_loss_fn
torch.manual_seed(42)
y_generated = torch.rand(20) * 10
loss = generator_loss_fn(y_generated, data_label=1)
print(loss)
test.assertAlmostEqual(loss.item(), 0.0222969, delta=1e-3)
tensor(0.0223)
Sampling from a GAN is straightforward, since it learns to generate data from an isotropic Gaussian latent space distribution.
There is an important nuance, however. Sampling is required during the process of training the GAN, since we generate fake images to show the discriminator. As you'll see in the next section, in some cases we'll need our samples to have gradients (i.e., to be part of the Generator's computation graph).
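One possible way to sketch this behaviour (the toy stand-in generator below is hypothetical, not the hw3 `Generator` API): wrap the forward pass in `torch.no_grad()` or `torch.enable_grad()` depending on whether the samples should carry a computation graph.

```python
import torch
import torch.nn as nn

def sample_sketch(gen, n, z_dim, device, with_grad=False):
    # no_grad() detaches the result from autograd, so sampled images carry
    # no gradient history; enable_grad() keeps them in the graph.
    ctx = torch.enable_grad() if with_grad else torch.no_grad()
    with ctx:
        z = torch.randn(n, z_dim, device=device)
        return gen(z)

# Tiny stand-in generator, just to demonstrate the grad_fn behaviour.
toy_gen = nn.Linear(16, 4)
s1 = sample_sketch(toy_gen, 5, 16, 'cpu', with_grad=False)
s2 = sample_sketch(toy_gen, 5, 16, 'cpu', with_grad=True)
print(s1.grad_fn is None, s2.grad_fn is not None)  # True True
```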
TODO: Implement the sample() method in the Generator class within the hw3/gan.py module.
samples = gen.sample(5, with_grad=False)
test.assertSequenceEqual(samples.shape, (5, *x0.shape[1:]))
test.assertIsNone(samples.grad_fn)
_ = plot.tensors_as_images(samples.cpu())
samples = gen.sample(5, with_grad=True)
test.assertSequenceEqual(samples.shape, (5, *x0.shape[1:]))
test.assertIsNotNone(samples.grad_fn)
Training GANs is a bit different since we need to train two models simultaneously, each with its own separate loss function and optimizer. We'll implement the training logic as a function that handles one batch of data and updates both the discriminator and the generator based on it.
As mentioned above, GANs are considered hard to train. To get some ideas and tips you can see this paper, this list of "GAN hacks" or just do it the hard way :)
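The per-batch logic described above can be sketched roughly as follows. This is a sketch of the idea only; the toy modules and loss lambdas at the bottom are hypothetical stand-ins, not the hw3 interfaces.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# One combined training step: update the discriminator on real data plus
# *detached* fakes, then update the generator through fakes that keep their
# gradients so the loss can flow back through the discriminator.
def train_batch_sketch(dsc, gen, dsc_loss_fn, gen_loss_fn,
                       dsc_optimizer, gen_optimizer, x_data):
    # Discriminator step: sampling without grad freezes the generator here.
    dsc_optimizer.zero_grad()
    fake = gen.sample(x_data.shape[0], with_grad=False)
    dsc_loss = dsc_loss_fn(dsc(x_data), dsc(fake))
    dsc_loss.backward()
    dsc_optimizer.step()

    # Generator step: fakes keep grad; only generator parameters are stepped.
    gen_optimizer.zero_grad()
    fake = gen.sample(x_data.shape[0], with_grad=True)
    gen_loss = gen_loss_fn(dsc(fake))
    gen_loss.backward()
    gen_optimizer.step()
    return dsc_loss.item(), gen_loss.item()

# Hypothetical toy modules, just to exercise the control flow end to end.
class ToyGen(nn.Module):
    def __init__(self, z_dim=8, out_dim=4):
        super().__init__()
        self.fc, self.z_dim = nn.Linear(z_dim, out_dim), z_dim
    def sample(self, n, with_grad=False):
        ctx = torch.enable_grad() if with_grad else torch.no_grad()
        with ctx:
            return self.fc(torch.randn(n, self.z_dim))

gen, dsc = ToyGen(), nn.Linear(4, 1)
d_loss, g_loss = train_batch_sketch(
    dsc, gen,
    lambda yr, yf: F.binary_cross_entropy_with_logits(yr, torch.ones_like(yr))
                   + F.binary_cross_entropy_with_logits(yf, torch.zeros_like(yf)),
    lambda yf: F.binary_cross_entropy_with_logits(yf, torch.ones_like(yf)),
    optim.SGD(dsc.parameters(), lr=0.01), optim.SGD(gen.parameters(), lr=0.01),
    torch.randn(6, 4))
```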
TODO:
Implement the train_batch function in the hw3/gan.py module.
Implement the part3_gan_hyperparams() function within the hw3/answers.py module.
import torch.optim as optim
from torch.utils.data import DataLoader
from hw3.answers import part3_gan_hyperparams
torch.manual_seed(42)
# Hyperparams
hp = part3_gan_hyperparams()
batch_size = hp['batch_size']
z_dim = hp['z_dim']
# Data
dl_train = DataLoader(ds_gwb, batch_size, shuffle=True)
im_size = ds_gwb[0][0].shape
# Model
dsc = gan.Discriminator(im_size).to(device)
gen = gan.Generator(z_dim, featuremap_size=4).to(device)
# Optimizer
def create_optimizer(model_params, opt_params):
opt_params = opt_params.copy()
optimizer_type = opt_params['type']
opt_params.pop('type')
return optim.__dict__[optimizer_type](model_params, **opt_params)
dsc_optimizer = create_optimizer(dsc.parameters(), hp['discriminator_optimizer'])
gen_optimizer = create_optimizer(gen.parameters(), hp['generator_optimizer'])
# Loss
def dsc_loss_fn(y_data, y_generated):
return gan.discriminator_loss_fn(y_data, y_generated, hp['data_label'], hp['label_noise'])
def gen_loss_fn(y_generated):
return gan.generator_loss_fn(y_generated, hp['data_label'])
# Training
checkpoint_file = 'checkpoints/gan'
checkpoint_file_final = f'{checkpoint_file}_final'
if os.path.isfile(f'{checkpoint_file}.pt'):
os.remove(f'{checkpoint_file}.pt')
# Show hypers
print(hp)
{'batch_size': 8, 'z_dim': 100, 'data_label': 1, 'label_noise': 0.2, 'discriminator_optimizer': {'type': 'Adam', 'lr': 0.0002, 'betas': (0.5, 0.999)}, 'generator_optimizer': {'type': 'Adam', 'lr': 0.0002, 'betas': (0.5, 0.999)}}
TODO:
Implement the save_checkpoint function in the hw3.gan module. You can decide on your own criterion regarding whether to save a checkpoint at the end of each epoch.
To submit with a trained model, copy your checkpoint file and add the suffix _final. When you run the main.py script to generate your submission, the final checkpoints file will be loaded instead of running training. Note that your final submission zip will not include the checkpoints/ folder. This is OK.
import IPython.display
import tqdm
from hw3.gan import train_batch, save_checkpoint
num_epochs = 100
if os.path.isfile(f'{checkpoint_file_final}.pt'):
print(f'*** Loading final checkpoint file {checkpoint_file_final} instead of training')
num_epochs = 0
gen = torch.load(f'{checkpoint_file_final}.pt', map_location=device,)
checkpoint_file = checkpoint_file_final
try:
dsc_avg_losses, gen_avg_losses = [], []
for epoch_idx in range(num_epochs):
# We'll accumulate batch losses and show an average once per epoch.
dsc_losses, gen_losses = [], []
print(f'--- EPOCH {epoch_idx+1}/{num_epochs} ---')
with tqdm.tqdm(total=len(dl_train.batch_sampler), file=sys.stdout) as pbar:
for batch_idx, (x_data, _) in enumerate(dl_train):
x_data = x_data.to(device)
dsc_loss, gen_loss = train_batch(
dsc, gen,
dsc_loss_fn, gen_loss_fn,
dsc_optimizer, gen_optimizer,
x_data)
dsc_losses.append(dsc_loss)
gen_losses.append(gen_loss)
pbar.update()
dsc_avg_losses.append(np.mean(dsc_losses))
gen_avg_losses.append(np.mean(gen_losses))
print(f'Discriminator loss: {dsc_avg_losses[-1]}')
print(f'Generator loss: {gen_avg_losses[-1]}')
if save_checkpoint(gen, dsc_avg_losses, gen_avg_losses, checkpoint_file):
print('Saved checkpoint.')
samples = gen.sample(5, with_grad=False)
fig, _ = plot.tensors_as_images(samples.cpu(), figsize=(6,2))
IPython.display.display(fig)
plt.close(fig)
except KeyboardInterrupt as e:
print('\n *** Training interrupted by user')
--- EPOCH 1/100 --- 100%| 67/67 [01:39<00:00, 1.48s/it] Discriminator loss: 0.2226154575986204 Generator loss: 10.399013006865088
--- EPOCH 2/100 --- 100%| 67/67 [01:51<00:00, 1.66s/it] Discriminator loss: 0.32406483643424155 Generator loss: 10.796153659251198
--- EPOCH 3/100 --- 100%| 67/67 [01:40<00:00, 1.50s/it] Discriminator loss: 0.6249632680538431 Generator loss: 7.866348905349845
--- EPOCH 4/100 --- 100%| 67/67 [01:53<00:00, 1.70s/it] Discriminator loss: 0.7597085722346804 Generator loss: 4.140996360956733
--- EPOCH 5/100 --- 100%| 67/67 [02:04<00:00, 1.86s/it] Discriminator loss: 0.7992192764780415 Generator loss: 4.1413471245053985
--- EPOCH 6/100 --- 100%| 67/67 [02:41<00:00, 2.40s/it] Discriminator loss: 0.760242141227224 Generator loss: 3.8490008097976003
--- EPOCH 7/100 --- 100%| 67/67 [01:25<00:00, 1.27s/it] Discriminator loss: 0.6488763174014305 Generator loss: 4.048653104412022
--- EPOCH 8/100 --- 100%| 67/67 [01:11<00:00, 1.07s/it] Discriminator loss: 0.692744659287716 Generator loss: 4.1238828772929175
--- EPOCH 9/100 --- 100%| 67/67 [01:10<00:00, 1.06s/it] Discriminator loss: 0.7140244154342964 Generator loss: 3.651231741727288
--- EPOCH 10/100 --- 100%| 67/67 [01:13<00:00, 1.10s/it] Discriminator loss: 0.6570371590109904 Generator loss: 4.1205917472269995
--- EPOCH 11/100 --- 100%| 67/67 [01:17<00:00, 1.16s/it] Discriminator loss: 0.7654145942695105 Generator loss: 3.452788523773649
--- EPOCH 12/100 --- 100%| 67/67 [01:13<00:00, 1.10s/it] Discriminator loss: 0.6773354987155146 Generator loss: 3.9122310796780373
--- EPOCH 13/100 --- 100%| 67/67 [01:14<00:00, 1.11s/it] Discriminator loss: 0.686968584558857 Generator loss: 3.6455368479685997
--- EPOCH 14/100 --- 100%| 67/67 [01:14<00:00, 1.11s/it] Discriminator loss: 0.6736044735272428 Generator loss: 3.794376923077142
--- EPOCH 15/100 --- 100%| 67/67 [01:21<00:00, 1.21s/it] Discriminator loss: 0.6158740751778902 Generator loss: 3.928209099306989
--- EPOCH 16/100 --- 100%| 67/67 [01:11<00:00, 1.07s/it] Discriminator loss: 0.7055960969248815 Generator loss: 3.748022351691972
--- EPOCH 17/100 --- 100%| 67/67 [01:11<00:00, 1.07s/it] Discriminator loss: 0.7214510603182351 Generator loss: 4.0355255336903815
--- EPOCH 18/100 --- 100%| 67/67 [01:12<00:00, 1.08s/it] Discriminator loss: 0.5925487618504176 Generator loss: 3.592918805222013
--- EPOCH 19/100 --- 100%| 67/67 [01:11<00:00, 1.07s/it] Discriminator loss: 0.6489447415319841 Generator loss: 3.7618876855764816
--- EPOCH 20/100 --- 100%| 67/67 [01:12<00:00, 1.08s/it] Discriminator loss: 0.6137689489926865 Generator loss: 3.728848683300303
--- EPOCH 21/100 --- 100%| 67/67 [01:20<00:00, 1.20s/it] Discriminator loss: 0.533989105373621 Generator loss: 4.0369778053084415
--- EPOCH 22/100 --- 100%| 67/67 [01:13<00:00, 1.09s/it] Discriminator loss: 0.6231796376184741 Generator loss: 4.205059199190852
--- EPOCH 23/100 --- 100%| 67/67 [01:14<00:00, 1.11s/it] Discriminator loss: 0.5462197742577809 Generator loss: 3.6763485750155662
--- EPOCH 24/100 --- 100%| 67/67 [01:15<00:00, 1.12s/it] Discriminator loss: 0.5820786693870131 Generator loss: 3.934920588536049
--- EPOCH 25/100 --- 100%| 67/67 [01:14<00:00, 1.11s/it] Discriminator loss: 0.48068115604457573 Generator loss: 3.8703388989861334
--- EPOCH 26/100 --- 100%| 67/67 [01:12<00:00, 1.08s/it] Discriminator loss: 0.5212051019406141 Generator loss: 4.146393323122566
--- EPOCH 27/100 --- 100%| 67/67 [01:11<00:00, 1.06s/it] Discriminator loss: 0.47505673365806467 Generator loss: 4.043247221121147
--- EPOCH 28/100 --- 100%| 67/67 [01:11<00:00, 1.07s/it] Discriminator loss: 0.48975709870235246 Generator loss: 3.9668766687165444 Saved checkpoint.
--- EPOCH 29/100 --- 100%| 67/67 [01:14<00:00, 1.12s/it] Discriminator loss: 0.41483799612789013 Generator loss: 4.285565899379217
--- EPOCH 30/100 --- 100%| 67/67 [01:12<00:00, 1.08s/it] Discriminator loss: 0.47531898019473945 Generator loss: 4.09228524165367
--- EPOCH 31/100 --- 100%| 67/67 [01:10<00:00, 1.06s/it] Discriminator loss: 0.4484067074676503 Generator loss: 4.203467078173339
--- EPOCH 32/100 --- 100%| 67/67 [01:11<00:00, 1.06s/it] Discriminator loss: 0.4615772899184654 Generator loss: 4.129581930032417
--- EPOCH 33/100 --- 100%| 67/67 [01:11<00:00, 1.06s/it] Discriminator loss: 0.34161930709187666 Generator loss: 4.388808421234586
--- EPOCH 34/100 --- 100%| 67/67 [01:12<00:00, 1.08s/it] Discriminator loss: 0.4297226559259553 Generator loss: 4.408588472586959
--- EPOCH 35/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.06s/it] Discriminator loss: 0.38908996421899367 Generator loss: 4.403439705051593
--- EPOCH 36/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00, 1.11s/it] Discriminator loss: 0.5031480472812901 Generator loss: 4.302793388046435
--- EPOCH 37/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.06s/it] Discriminator loss: 0.3036700967532485 Generator loss: 4.256853963012126
--- EPOCH 38/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.06s/it] Discriminator loss: 0.44352790401942693 Generator loss: 4.935855774737116
--- EPOCH 39/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:15<00:00, 1.13s/it] Discriminator loss: 0.34675667462731474 Generator loss: 4.575066045149049
--- EPOCH 40/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00, 1.06s/it] Discriminator loss: 0.40445677874915636 Generator loss: 4.601134561780674
--- EPOCH 41/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00, 1.08s/it] Discriminator loss: 0.35126391709295673 Generator loss: 4.461568499678996
--- EPOCH 42/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.06s/it] Discriminator loss: 0.31365142673698826 Generator loss: 4.84396594140067
--- EPOCH 43/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00, 1.08s/it] Discriminator loss: 0.33251275033203526 Generator loss: 4.469906568527222
--- EPOCH 44/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.06s/it] Discriminator loss: 0.279036117523019 Generator loss: 4.753364780055943
--- EPOCH 45/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00, 1.06s/it] Discriminator loss: 0.42701762508767754 Generator loss: 5.096534817966063
--- EPOCH 46/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00, 1.08s/it] Discriminator loss: 0.40647909551191685 Generator loss: 4.550182789119322
--- EPOCH 47/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.06s/it] Discriminator loss: 0.24964745699970134 Generator loss: 4.565602247394732
--- EPOCH 48/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00, 1.09s/it] Discriminator loss: 0.26656937849388196 Generator loss: 4.488548634657219
--- EPOCH 49/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00, 1.09s/it] Discriminator loss: 0.2868878970951287 Generator loss: 5.13614484089524
--- EPOCH 50/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00, 1.06s/it] Discriminator loss: 0.2893540797504916 Generator loss: 4.9686783374245485
--- EPOCH 51/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.07s/it] Discriminator loss: 0.4205938978537695 Generator loss: 5.042785594712442
--- EPOCH 52/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.06s/it] Discriminator loss: 0.39453655745444904 Generator loss: 5.26520986521422
--- EPOCH 53/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00, 1.09s/it] Discriminator loss: 0.3967029097587315 Generator loss: 4.636836155137019
--- EPOCH 54/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00, 1.08s/it] Discriminator loss: 0.2517954114435324 Generator loss: 4.709473912395648
--- EPOCH 55/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.07s/it] Discriminator loss: 0.21090908240137704 Generator loss: 4.746104042921493
--- EPOCH 56/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00, 1.10s/it] Discriminator loss: 0.2222719030704961 Generator loss: 5.1047546383160265
--- EPOCH 57/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00, 1.09s/it] Discriminator loss: 0.2371100701892109 Generator loss: 5.161920148934891
--- EPOCH 58/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00, 1.09s/it] Discriminator loss: 0.23250969593871884 Generator loss: 5.226165280413272
--- EPOCH 59/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00, 1.10s/it] Discriminator loss: 0.1713197364735959 Generator loss: 5.229283902182508
--- EPOCH 60/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00, 1.09s/it] Discriminator loss: 0.3744734670244046 Generator loss: 5.22563999802319
--- EPOCH 61/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:15<00:00, 1.13s/it] Discriminator loss: 0.21818367603109845 Generator loss: 5.034501529451626
--- EPOCH 62/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00, 1.10s/it] Discriminator loss: 0.2961589781595255 Generator loss: 5.973364061384059
--- EPOCH 63/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00, 1.09s/it] Discriminator loss: 0.24063466767321773 Generator loss: 4.941025922547525
--- EPOCH 64/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:16<00:00, 1.14s/it] Discriminator loss: 0.23336398223442817 Generator loss: 4.79269886728543 Saved checkpoint.
--- EPOCH 65/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.06s/it] Discriminator loss: 0.2037594327564115 Generator loss: 5.494116355234118
--- EPOCH 66/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.07s/it] Discriminator loss: 0.28042146480127944 Generator loss: 5.429578503566002
--- EPOCH 67/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00, 1.08s/it] Discriminator loss: 0.2380366637857992 Generator loss: 5.586613768961892
--- EPOCH 68/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.06s/it] Discriminator loss: 0.24321508038538828 Generator loss: 5.506548023935574
--- EPOCH 69/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00, 1.08s/it] Discriminator loss: 0.19822439951683157 Generator loss: 5.6423968984119925
--- EPOCH 70/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.07s/it] Discriminator loss: 0.19567888513652246 Generator loss: 5.2913757936278385
--- EPOCH 71/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00, 1.08s/it] Discriminator loss: 0.216128641795089 Generator loss: 5.444126164735253
--- EPOCH 72/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.07s/it] Discriminator loss: 0.1359927723768042 Generator loss: 5.344446449137446
--- EPOCH 73/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00, 1.08s/it] Discriminator loss: 0.20982019202922708 Generator loss: 5.672997513813759
--- EPOCH 74/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00, 1.11s/it] Discriminator loss: 0.16212283858834808 Generator loss: 5.464307156961356
--- EPOCH 75/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.07s/it] Discriminator loss: 0.36618489243868574 Generator loss: 6.051519253360691
--- EPOCH 76/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.07s/it] Discriminator loss: 0.17745266220907666 Generator loss: 4.925476590199257
--- EPOCH 77/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00, 1.06s/it] Discriminator loss: 0.18689542805859402 Generator loss: 5.445575205247794
--- EPOCH 78/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00, 1.05s/it] Discriminator loss: 0.1568165964155055 Generator loss: 5.459341437069337
--- EPOCH 79/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.07s/it] Discriminator loss: 0.16645404059829108 Generator loss: 5.654817463746712
--- EPOCH 80/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:09<00:00, 1.04s/it] Discriminator loss: 0.24319110466028326 Generator loss: 5.824600757057987
--- EPOCH 81/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:19<00:00, 1.18s/it] Discriminator loss: 0.1486799336080231 Generator loss: 5.589747959108495
--- EPOCH 82/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:28<00:00, 1.32s/it] Discriminator loss: 0.1359346567917226 Generator loss: 5.59498687644503
--- EPOCH 83/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:16<00:00, 1.14s/it] Discriminator loss: 0.3582810720623429 Generator loss: 6.1349239803072235
--- EPOCH 84/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00, 1.09s/it] Discriminator loss: 0.21791031133772723 Generator loss: 5.550275245709206
--- EPOCH 85/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:12<00:00, 1.09s/it] Discriminator loss: 0.1460703068982754 Generator loss: 5.643578984844151
--- EPOCH 86/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00, 1.11s/it] Discriminator loss: 0.1680278676881719 Generator loss: 5.423344418184081
--- EPOCH 87/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00, 1.05s/it] Discriminator loss: 0.13634364540452387 Generator loss: 6.287931374649503
--- EPOCH 88/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:10<00:00, 1.05s/it] Discriminator loss: 0.15707427296620696 Generator loss: 6.208204162654592
--- EPOCH 89/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:15<00:00, 1.13s/it] Discriminator loss: 0.23691063135195134 Generator loss: 6.708966758713793
--- EPOCH 90/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:18<00:00, 1.17s/it] Discriminator loss: 0.1935963678215422 Generator loss: 5.781576846962545
--- EPOCH 91/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:15<00:00, 1.13s/it] Discriminator loss: 0.1911312595232209 Generator loss: 5.672198094538788 Saved checkpoint.
--- EPOCH 92/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:20<00:00, 1.20s/it] Discriminator loss: 0.11797445981916208 Generator loss: 5.658509987503735 Saved checkpoint.
--- EPOCH 93/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:13<00:00, 1.10s/it] Discriminator loss: 0.06029647434436118 Generator loss: 5.692577867365595
--- EPOCH 94/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.06s/it] Discriminator loss: 0.21785901588346088 Generator loss: 5.82088153397859
--- EPOCH 95/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.07s/it] Discriminator loss: 0.11724033395745861 Generator loss: 5.538199930048701
--- EPOCH 96/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:11<00:00, 1.06s/it] Discriminator loss: 0.12432931063335333 Generator loss: 6.089405600704364
--- EPOCH 97/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:14<00:00, 1.11s/it] Discriminator loss: 0.1921722201991882 Generator loss: 6.962398450766036
--- EPOCH 98/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [01:58<00:00, 1.76s/it] Discriminator loss: 0.12687940124088704 Generator loss: 6.596022196670077
--- EPOCH 99/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [02:23<00:00, 2.14s/it] Discriminator loss: 0.2868398177757192 Generator loss: 6.26818527392487
--- EPOCH 100/100 --- 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [02:32<00:00, 2.28s/it] Discriminator loss: 0.1492465426450345 Generator loss: 6.200201539850947
# Plot images from best or last model
if os.path.isfile(f'{checkpoint_file}.pt'):
gen = torch.load(f'{checkpoint_file}.pt', map_location=device)
print('*** Images Generated from best model:')
samples = gen.sample(n=15, with_grad=False).cpu()
fig, _ = plot.tensors_as_images(samples, nrows=3, figsize=(6,6))
*** Images Generated from best model:
TODO Answer the following questions. Write your answers in the appropriate variables in the module hw3/answers.py.
from cs236781.answers import display_answer
import hw3.answers as answers
Explain in detail why during training we sometimes need to maintain gradients when sampling from the GAN, and other times we don't. When are they maintained and why? When are they discarded and why?
display_answer(answers.part3_q1)
Your answer: Gradients through the sampling are discarded when training the discriminator. In that phase the generated images are used only as fixed "fake" inputs to the discriminator, so we sample without gradients (detaching the samples). This saves memory and computation, and it ensures that the discriminator's loss does not backpropagate into the generator's parameters, which could happen unintentionally otherwise. Gradients are maintained when training the generator. The generator's loss is computed by passing its samples through the (frozen) discriminator, so the gradient must flow from the discriminator's output back through the generated samples into the generator's parameters; without it, the generator could not update at all.
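As a minimal sketch of this distinction, using hypothetical tiny linear networks `G` and `D` in place of the real generator and discriminator (the `with torch.no_grad()` block plays the role of `sample(..., with_grad=False)`):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
G = nn.Linear(4, 8)   # stand-in for the generator
D = nn.Linear(8, 1)   # stand-in for the discriminator
z = torch.randn(16, 4)

# Discriminator step: samples are detached, so no gradient reaches G.
with torch.no_grad():
    fake_detached = G(z)
D(fake_detached).sum().backward()
assert G.weight.grad is None          # G received no gradient
assert D.weight.grad is not None      # D did

# Generator step: gradients must flow through D back into G.
fake = G(z)                           # sampled WITH grad this time
g_loss = -D(fake).sum()
g_loss.backward()
assert G.weight.grad is not None      # now G can be updated
```

Note that in the generator step the discriminator's parameters also accumulate gradients, but since only the generator's optimizer takes a step, the discriminator remains effectively frozen.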
When training a GAN to generate images, should we decide to stop training solely based on the fact that the Generator loss is below some threshold? Why or why not?
What does it mean if the discriminator loss remains at a constant value while the generator loss decreases?
display_answer(answers.part3_q2)
Your answer: 1) We should not stop training just because the generator loss is below a certain threshold, because a low generator loss does not mean that the generator produces good images. The generator loss measures only its ability to fool the current discriminator; it does not measure sample quality. If the discriminator is weak, the generator can achieve a low loss with bad samples that nevertheless fool it. Since both players change over time, the loss is also non-stationary, so an absolute threshold is not meaningful.
2) If the discriminator loss remains constant while the generator loss decreases, the discriminator is not improving at telling real samples from fake ones, while the generator keeps getting better at fooling it, i.e. produces samples that look increasingly real to the current discriminator.
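To see concretely that the generator's loss reflects only the discriminator's verdict and not sample quality, consider this toy computation with hypothetical, hand-picked discriminator scores (probability of "real") for two scenarios:

```python
import torch
import torch.nn.functional as F

# Case A: a weak discriminator assigns a high "real" probability
# to poor fakes -> low generator loss despite bad samples.
d_scores_poor_fakes = torch.full((8,), 0.95)
loss_a = F.binary_cross_entropy(d_scores_poor_fakes, torch.ones(8))

# Case B: a strong discriminator correctly flags decent fakes
# -> high generator loss despite better samples.
d_scores_decent_fakes = torch.full((8,), 0.10)
loss_b = F.binary_cross_entropy(d_scores_decent_fakes, torch.ones(8))

print(f'loss A: {loss_a.item():.3f}, loss B: {loss_b.item():.3f}')
assert loss_a < loss_b  # lower generator loss, worse samples
```

The numbers are illustrative only, but they show why a threshold on the generator loss alone is not a sound stopping criterion.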
Compare the results you got when generating images with the VAE to the GAN results. What's the main difference and what's causing it?
display_answer(answers.part3_q3)
Your answer: The images generated by the VAE are smoother and more clearly focused on human faces, whereas those generated by the GAN are noisier and show more color variation, including in the background. The main cause is the difference in the loss functions (and the architectures built around them). The VAE loss contains an explicit reconstruction term that compares each output directly to a real image from the dataset, which pushes the decoder to reproduce the features shared across the dataset (the common face) and to average out everything else, producing smooth results that largely ignore background color. The GAN loss, in contrast, comes from a game-theoretic setup: the generator is never compared directly to a real image, only judged by the discriminator's verdict on the whole image, background and colors included. Imperfections that still fool the discriminator are not penalized, which leads to the noisier outputs we observe.
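The difference between the two training signals can be sketched as follows; the random tensors are hypothetical stand-ins for real data and model outputs, and the point is only which quantities each loss depends on:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(4, 3, 8, 8)       # a batch of "real" images
x_rec = torch.rand(4, 3, 8, 8)   # hypothetical VAE decoder outputs

# VAE: the reconstruction term compares each output to its input
# pixel by pixel, so smooth, averaged-out outputs are rewarded.
vae_recon = F.mse_loss(x_rec, x)

# GAN: the generator loss never sees x at all - it depends only on
# the discriminator's scores for the generated batch.
d_scores = torch.rand(4)         # hypothetical discriminator outputs
gan_gen = F.binary_cross_entropy(d_scores, torch.ones(4))

print(f'VAE recon: {vae_recon.item():.3f}, GAN gen: {gan_gen.item():.3f}')
```

Because only the VAE loss is anchored to specific dataset images, it is the one that directly enforces pixel-level fidelity to the common content of the data.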